Class FastLoadedDiceRollerDiscreteSampler
- java.lang.Object
-
- org.apache.commons.rng.sampling.distribution.FastLoadedDiceRollerDiscreteSampler
-
- All Implemented Interfaces:
DiscreteSampler, SharedStateDiscreteSampler, SharedStateSampler<SharedStateDiscreteSampler>
- Direct Known Subclasses:
FastLoadedDiceRollerDiscreteSampler.FixedValueDiscreteSampler, FastLoadedDiceRollerDiscreteSampler.FLDRSampler
public abstract class FastLoadedDiceRollerDiscreteSampler extends java.lang.Object implements SharedStateDiscreteSampler
Distribution sampler that uses the Fast Loaded Dice Roller (FLDR). It can be used to sample from n values each with an associated relative weight. If all unique items are assigned the same weight it is more efficient to use the DiscreteUniformSampler.
Given a list L of n positive numbers, where L[i] represents the relative weight of the i-th side, FLDR returns integer i with relative probability L[i].
FLDR produces exact samples from the specified probability distribution.
- For integer weights, the probability of returning i is precisely equal to the rational number L[i] / m, where m is the sum of L.
- For floating-point weights, each weight L[i] is converted to the corresponding rational number p[i] / q[i] where p[i] is a positive integer and q[i] is a power of 2. The rational weights are then normalized (exactly) to sum to unity.
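The conversion of a floating-point weight to a rational number with a power-of-2 denominator can be illustrated with the IEEE 754 bit layout of a double. The following is a minimal sketch, not the library's code: the class name DoubleAsRational and the method decompose are hypothetical, and only JDK methods are used.

```java
final class DoubleAsRational {
    /**
     * Returns {p, e} such that w == p * 2^e exactly, with p odd (or zero).
     * When e is negative this is the rational p / q with q = 2^(-e).
     * Hypothetical helper for illustration only.
     */
    static long[] decompose(double w) {
        if (!(w >= 0) || Double.isInfinite(w)) {
            throw new IllegalArgumentException("Invalid weight: " + w);
        }
        long bits = Double.doubleToRawLongBits(w);
        int biased = (int) (bits >>> 52) & 0x7ff;   // 11-bit biased exponent
        long mantissa = bits & ((1L << 52) - 1);    // 52-bit mantissa
        if (biased == 0 && mantissa == 0) {
            return new long[] {0, 0};               // zero weight
        }
        // Normal numbers have an implicit leading 1 bit; sub-normal numbers do not.
        long p = biased == 0 ? mantissa : mantissa | (1L << 52);
        // 1075 = exponent bias (1023) + mantissa size (52)
        int e = (biased == 0 ? 1 : biased) - 1075;
        // Remove trailing zero bits so that p is odd
        int tz = Long.numberOfTrailingZeros(p);
        return new long[] {p >>> tz, e + tz};
    }

    public static void main(String[] args) {
        long[] r = decompose(0.75);
        System.out.println(r[0] + " * 2^" + r[1]);  // prints "3 * 2^-2"
    }
}
```

Here 0.75 is represented exactly as 3 / 4. A value such as 0.1 has no exact binary representation, so the rational obtained is that of the nearest double.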
Note that if exact samples are not required then an alternative sampler that ignores very small relative weights may have improved sampling performance.
This implementation is based on the algorithm in:
Feras A. Saad, Cameron E. Freer, Martin C. Rinard, and Vikash K. Mansinghka. The Fast Loaded Dice Roller: A Near-Optimal Exact Sampler for Discrete Probability Distributions. In AISTATS 2020: Proceedings of the 23rd International Conference on Artificial Intelligence and Statistics, Proceedings of Machine Learning Research 108, Palermo, Sicily, Italy, 2020.
Sampling uses UniformRandomProvider.nextInt() as the source of random bits.
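For orientation, the FLDR algorithm described in the cited paper can be sketched for integer frequencies as below. This is an illustrative sketch under stated assumptions and not the implementation in this class: the class FldrSketch and its members are hypothetical, java.util.Random stands in for UniformRandomProvider, the sum is limited to 2^62, and the single-category edge case is rejected rather than handled by a dedicated sampler.

```java
import java.util.Arrays;
import java.util.Random;

final class FldrSketch {
    private final int n;          // number of categories
    private final int[] h;        // h[j]: number of leaf nodes at tree level j
    private final int[][] label;  // label[d][j]: category of the d-th leaf at level j
    private final Random rng;     // stand-in for UniformRandomProvider (assumption)
    private int bits;             // buffered random bits from nextInt()
    private int bitsLeft;         // number of unused bits in the buffer

    FldrSketch(Random rng, long[] frequencies) {
        this.rng = rng;
        n = frequencies.length;
        long m = 0;
        int nonZero = 0;
        for (long f : frequencies) {
            if (f < 0) {
                throw new IllegalArgumentException("Negative frequency");
            }
            nonZero += f != 0 ? 1 : 0;
            m = Math.addExact(m, f);
        }
        // Sketch-only limits: a single category needs no tree; 2^62 avoids overflow
        if (nonZero < 2 || m > (1L << 62)) {
            throw new IllegalArgumentException("Unsupported input for this sketch");
        }
        int k = 64 - Long.numberOfLeadingZeros(m - 1);  // k == ceil(log2(m))
        long[] a = Arrays.copyOf(frequencies, n + 1);
        a[n] = (1L << k) - m;                           // extra "reject" side
        h = new int[k];
        label = new int[n + 1][k];
        // Bit (k - 1 - j) of a[i] places a leaf labelled i at level j
        for (int j = 0; j < k; j++) {
            for (int i = 0; i <= n; i++) {
                if (((a[i] >>> (k - 1 - j)) & 1L) != 0) {
                    label[h[j]++][j] = i;
                }
            }
        }
    }

    int sample() {
        int d = 0;  // position among the nodes at the current level
        int c = 0;  // current level
        while (true) {
            d = 2 * d + (1 - nextBit());
            if (d < h[c]) {
                int i = label[d][c];
                if (i < n) {
                    return i;
                }
                d = 0;  // reject leaf: restart from the root
                c = 0;
            } else {
                d -= h[c];
                c++;
            }
        }
    }

    private int nextBit() {
        if (bitsLeft == 0) {
            bits = rng.nextInt();
            bitsLeft = 32;
        }
        int b = bits & 1;
        bits >>>= 1;
        bitsLeft--;
        return b;
    }
}
```

With frequencies {1, 0, 2, 5} the sum is 8, so the tree has depth 3 and no reject leaf; category 1 is never returned and category 3 is returned with probability 5/8.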
-
-
Nested Class Summary
Nested Classes
Modifier and Type | Class | Description
private static class | FastLoadedDiceRollerDiscreteSampler.FixedValueDiscreteSampler | Class to handle the edge case of observations in only one category.
private static class | FastLoadedDiceRollerDiscreteSampler.FLDRSampler | Class to implement the FLDR sample algorithm.
-
Field Summary
Fields
Modifier and Type | Field | Description
private static long | MANTISSA_MASK | Mask to extract the 52-bit mantissa from a long representation of a double.
private static int | MANTISSA_SIZE | Size of the mantissa of a double.
private static int | MAX_ARRAY_SIZE | The maximum size of an array.
private static int | MAX_BIASED_EXPONENT | The maximum biased exponent for a finite double.
private static java.math.BigInteger | MAX_LONG | BigInteger representation of Long.MAX_VALUE.
private static int | MAX_OFFSET | The maximum offset that will avoid loss of bits for a left shift of a 53-bit value.
private static int | NO_LABEL | Initial value for no leaf node label.
private static java.lang.String | SAMPLER_NAME | Name of the sampler.
-
Constructor Summary
Constructors
Constructor | Description
FastLoadedDiceRollerDiscreteSampler() | Package-private constructor.
-
Method Summary
Modifier and Type | Method | Description
(package private) static int | checkArraySize(long size) | Check the size is valid for a 1D array.
private static int | checkWeightsNonZeroLength(double[] weights) | Check the weights have a non-zero length.
private static void | convertToIntegers(double[] weights, long[] values, int[] exponents, int alpha) | Convert the floating-point weights to relative weights represented as integers value * 2^exponent.
private static FastLoadedDiceRollerDiscreteSampler | createSampler(UniformRandomProvider rng, long[] frequencies, int[] offsets, int[] indices, java.math.BigInteger m) | Creates the sampler.
private static FastLoadedDiceRollerDiscreteSampler | createSampler(UniformRandomProvider rng, long[] frequencies, int[] indices, long m) | Creates the sampler.
private static void | filterWeights(long[] values, int[] exponents, int alpha, int maxExponent) | Filters small weights using the alpha parameter.
(package private) static int | indexOfNonZero(long[] frequencies) | Find the index of the first non-zero frequency.
private static int[] | indicesOfNonZero(long[] values) | Create the indices of non-zero values.
static FastLoadedDiceRollerDiscreteSampler | of(UniformRandomProvider rng, double[] weights) | Creates a sampler.
static FastLoadedDiceRollerDiscreteSampler | of(UniformRandomProvider rng, double[] weights, int alpha) | Creates a sampler.
static FastLoadedDiceRollerDiscreteSampler | of(UniformRandomProvider rng, long[] frequencies) | Creates a sampler.
private static void | scaleWeights(long[] values, int[] exponents) | Scale the weights represented as integers value * 2^exponent to use a minimum exponent of zero.
private static long | sum(long[] frequencies) | Sum the frequencies.
private static java.math.BigInteger | sum(long[] values, int[] exponents, int[] indices) | Sum the integers at the specified indices.
private static boolean | testBit(long value, int offset, int n) | Test the logical bit of the shifted integer representation.
private static java.math.BigInteger | toBigInteger(long value, int offset) | Convert the value and left shift offset to a BigInteger.
abstract FastLoadedDiceRollerDiscreteSampler | withUniformRandomProvider(UniformRandomProvider rng) | Create a new instance of the sampler with the same underlying state using the given uniform random provider as the source of randomness.
-
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
-
Methods inherited from interface org.apache.commons.rng.sampling.distribution.DiscreteSampler
sample, samples, samples
-
-
-
-
Field Detail
-
MAX_ARRAY_SIZE
private static final int MAX_ARRAY_SIZE
The maximum size of an array.
This value is taken from the limit in OpenJDK 8 java.util.ArrayList. It allows VMs to reserve some header words in an array.
- See Also:
- Constant Field Values
-
MAX_BIASED_EXPONENT
private static final int MAX_BIASED_EXPONENT
The maximum biased exponent for a finite double. This is offset by 1023 from Math.getExponent(Double.MAX_VALUE).
- See Also:
- Constant Field Values
-
MANTISSA_SIZE
private static final int MANTISSA_SIZE
Size of the mantissa of a double. Equal to 52 bits.
- See Also:
- Constant Field Values
-
MANTISSA_MASK
private static final long MANTISSA_MASK
Mask to extract the 52-bit mantissa from a long representation of a double.
- See Also:
- Constant Field Values
-
MAX_LONG
private static final java.math.BigInteger MAX_LONG
BigInteger representation of Long.MAX_VALUE.
-
MAX_OFFSET
private static final int MAX_OFFSET
The maximum offset that will avoid loss of bits for a left shift of a 53-bit value. The value will remain positive for any shift <= this value.
- See Also:
- Constant Field Values
-
NO_LABEL
private static final int NO_LABEL
Initial value for no leaf node label.
- See Also:
- Constant Field Values
-
SAMPLER_NAME
private static final java.lang.String SAMPLER_NAME
Name of the sampler.
- See Also:
- Constant Field Values
-
-
Method Detail
-
withUniformRandomProvider
public abstract FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng)
Create a new instance of the sampler with the same underlying state using the given uniform random provider as the source of randomness.
- Specified by:
withUniformRandomProvider in interface SharedStateSampler<SharedStateDiscreteSampler>
- Parameters:
rng - Generator of uniformly distributed random numbers.
- Returns:
- the sampler
-
of
public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng, long[] frequencies)
Creates a sampler.
Note: The discrete distribution generating (DDG) tree requires (n + 1) * k entries where n is the number of categories, k == ceil(log2(m)) and m is the sum of the observed frequencies. An exception is raised if this cannot be allocated as a single array.
For reference the sum is limited to Long.MAX_VALUE and the value k to 63. The number of categories is limited to approximately ((2^31 - 1) / k) = 34,087,042 when the sum of frequencies is large enough to create k=63.
- Parameters:
rng - Generator of uniformly distributed random numbers.
frequencies - Observed frequencies of the discrete distribution.
- Returns:
- the sampler
- Throws:
java.lang.IllegalArgumentException - if frequencies is null or empty, a frequency is negative, the sum of all frequencies is either zero or above Long.MAX_VALUE, or the size of the discrete distribution generating tree is too large.
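The size limits quoted above can be reproduced with a few lines of integer arithmetic. This is a minimal sketch using hypothetical names (DdgTreeSize, depth, entries); it only evaluates the formulas stated in the note and does not reproduce the library's own size check.

```java
final class DdgTreeSize {
    /** Returns k == ceil(log2(m)) for a positive sum m. */
    static int depth(long m) {
        if (m < 1) {
            throw new IllegalArgumentException("Sum must be positive: " + m);
        }
        return 64 - Long.numberOfLeadingZeros(m - 1);
    }

    /** Returns the number of entries (n + 1) * k, computed as a long to avoid int overflow. */
    static long entries(int n, long m) {
        return (n + 1L) * depth(m);
    }

    public static void main(String[] args) {
        // 6 categories with frequencies summing to 21: k = 5, so 7 * 5 entries
        System.out.println(entries(6, 21));          // prints 35
        // The largest supported sum gives the largest depth
        System.out.println(depth(Long.MAX_VALUE));   // prints 63
        // Approximate limit on the number of categories when k = 63
        System.out.println(Integer.MAX_VALUE / 63);  // prints 34087042
    }
}
```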
-
of
public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng, double[] weights)
Creates a sampler.
Weights are converted to rational numbers p / q where q is a power of 2. The numerators p are scaled to use a common denominator before summing.
All weights are used to create the sampler. Weights with a small magnitude relative to the largest weight can be excluded using the constructor method with the relative magnitude parameter alpha (see of(UniformRandomProvider, double[], int)).
- Parameters:
rng - Generator of uniformly distributed random numbers.
weights - Weights of the discrete distribution.
- Returns:
- the sampler
- Throws:
java.lang.IllegalArgumentException - if weights is null or empty, a weight is negative, infinite or NaN, the sum of all weights is zero, or the size of the discrete distribution generating tree is too large.
- See Also:
of(UniformRandomProvider, double[], int)
-
of
public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng, double[] weights, int alpha)
Creates a sampler.
Weights are converted to rational numbers p / q where q is a power of 2. The numerators p are scaled to use a common denominator before summing.
Note: The discrete distribution generating (DDG) tree requires (n + 1) * k entries where n is the number of categories, k == ceil(log2(m)) and m is the sum of the weight numerators p. An exception is raised if this cannot be allocated as a single array.
For reference the value k is equal to or greater than the ratio of the largest to the smallest weight expressed as a power of 2. For Double.MAX_VALUE / Double.MIN_VALUE this is ~2098. The value k increases with the sum of the weight numerators. A number of weights in excess of 1,000,000 with values equal to Double.MAX_VALUE would be required to raise an exception when the minimum weight is Double.MIN_VALUE.
Weights with a small magnitude relative to the largest weight can be excluded using the relative magnitude parameter alpha. This will set any weight to zero if the magnitude is approximately 2^alpha smaller than the largest weight. This comparison is made using only the exponent of the input weights. The alpha parameter is ignored if not above zero. Note that a small alpha parameter will exclude more weights than a large alpha parameter.
The alpha parameter can be used to exclude categories that have a very low probability of occurrence and will improve the construction performance of the sampler. The effect on sampling performance depends on the relative weights of the excluded categories; typically a high alpha is used to exclude categories that would be visited with a very low probability and the sampling performance is unchanged.
Implementation Note
This method creates a sampler with exact samples from the specified probability distribution. It is recommended to use this method:
- if the weights are computed, for example from a probability mass function; or
- if the weights sum to an infinite value.
If the weights are computed from empirical observations then it is recommended to use the factory method accepting frequencies. This requires the total number of observations to be representable as a long integer.
Note that if all weights are scaled by a power of 2 to be integers, and each integer can be represented as a positive 64-bit long value, then the sampler created using this method will match the output from a sampler created with the scaled weights converted to long values for the factory method accepting frequencies. This assumes the sum of the integer values does not overflow.
It should be noted that the conversion of weights to rational numbers has a performance overhead during construction (sampling performance is not affected). This may be avoided by first converting them to integer values that can be summed without overflow, for example by scaling the values by 2^62 / sum and converting to long by casting or rounding.
This approach may increase the efficiency of construction. The resulting sampler may no longer produce exact samples from the distribution. In particular any weights with a converted frequency of zero cannot be sampled.
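The scaling by 2^62 / sum mentioned above might look like the following sketch. The class WeightsToFrequencies and its method are hypothetical; the conversion uses a cast (truncation), and the result would then be passed to the factory method accepting frequencies.

```java
final class WeightsToFrequencies {
    /**
     * Converts weights to long frequencies that sum to approximately 2^62.
     * Hypothetical helper: the result is an approximation of the input distribution.
     */
    static long[] toFrequencies(double[] weights) {
        double sum = 0;
        for (double w : weights) {
            if (!(w >= 0) || Double.isInfinite(w)) {
                throw new IllegalArgumentException("Invalid weight: " + w);
            }
            sum += w;
        }
        if (!(sum > 0) || Double.isInfinite(sum)) {
            throw new IllegalArgumentException("Invalid sum: " + sum);
        }
        long[] frequencies = new long[weights.length];
        for (int i = 0; i < weights.length; i++) {
            // Normalise first, then scale by 2^62; the cast truncates towards zero
            frequencies[i] = (long) (weights[i] / sum * 0x1.0p62);
        }
        return frequencies;
    }

    public static void main(String[] args) {
        long[] f = toFrequencies(new double[] {1e-30, 1.0});
        // The tiny weight is converted to a frequency of zero and cannot be sampled
        System.out.println(f[0]);  // prints 0
    }
}
```

The second example shows the loss of exactness described above: a weight that is about 2^-100 of the total is truncated to zero.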
- Parameters:
rng - Generator of uniformly distributed random numbers.
weights - Weights of the discrete distribution.
alpha - Alpha parameter.
- Returns:
- the sampler
- Throws:
java.lang.IllegalArgumentException - if weights is null or empty, a weight is negative, infinite or NaN, the sum of all weights is zero, or the size of the discrete distribution generating tree is too large.
- See Also:
of(UniformRandomProvider, long[])
-
sum
private static long sum(long[] frequencies)
Sum the frequencies.
- Parameters:
frequencies - Frequencies.
- Returns:
- the sum
- Throws:
java.lang.IllegalArgumentException - if frequencies is null or empty, a frequency is negative, or the sum of all frequencies is either zero or above Long.MAX_VALUE
-
convertToIntegers
private static void convertToIntegers(double[] weights, long[] values, int[] exponents, int alpha)
Convert the floating-point weights to relative weights represented as integers value * 2^exponent. The relative weight as an integer is:
BigInteger.valueOf(value).shiftLeft(exponent)
Note that the weights are created using a common power-of-2 scaling operation so the minimum exponent is zero.
A positive alpha parameter is used to set any weight to zero if the magnitude is approximately 2^alpha smaller than the largest weight. This comparison is made using only the exponent of the input weights.
- Parameters:
weights - Weights of the discrete distribution.
values - Output floating-point mantissas converted to 53-bit integers.
exponents - Output power of 2 exponent.
alpha - Alpha parameter.
- Throws:
java.lang.IllegalArgumentException - if a weight is negative, infinite or NaN, or the sum of all weights is zero.
-
filterWeights
private static void filterWeights(long[] values, int[] exponents, int alpha, int maxExponent)
Filters small weights using the alpha parameter. A positive alpha parameter is used to set any weight to zero if the magnitude is approximately 2^alpha smaller than the largest weight. This comparison is made using only the exponent of the input weights.
- Parameters:
values - 53-bit values.
exponents - Power of 2 exponent.
alpha - Alpha parameter.
maxExponent - Maximum exponent.
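The exponent-only comparison can be illustrated directly on double values. The sketch below is an assumption about the behaviour described above, not the private method itself: the class AlphaFilterSketch is hypothetical, it operates on doubles rather than the value and exponent arrays, and the exact boundary condition (strictly greater than alpha) is assumed.

```java
final class AlphaFilterSketch {
    /**
     * Returns a copy of the weights where any weight whose binary exponent is more than
     * alpha below the largest exponent is set to zero. A non-positive alpha disables the filter.
     */
    static double[] filter(double[] weights, int alpha) {
        double[] out = weights.clone();
        if (alpha <= 0) {
            return out;
        }
        int maxExponent = Integer.MIN_VALUE;
        for (double w : weights) {
            if (w > 0) {
                maxExponent = Math.max(maxExponent, Math.getExponent(w));
            }
        }
        for (int i = 0; i < out.length; i++) {
            // Only the exponent is compared; the mantissa is ignored (assumed boundary: > alpha)
            if (out[i] > 0 && maxExponent - Math.getExponent(out[i]) > alpha) {
                out[i] = 0;
            }
        }
        return out;
    }
}
```

For weights {1.0, 2^-10, 2^-60}, an alpha of 20 removes only the smallest weight, whereas an alpha of 5 also removes 2^-10. This matches the note that a small alpha excludes more weights than a large alpha.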
-
scaleWeights
private static void scaleWeights(long[] values, int[] exponents)
Scale the weights represented as integers value * 2^exponent to use a minimum exponent of zero. The values are scaled to remove any common trailing zeros in their representation. This ultimately reduces the size of the discrete distribution generating (DDG) tree.
- Parameters:
values - 53-bit values.
exponents - Power of 2 exponent.
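One way to picture this scaling is shown below. It is a sketch under assumptions, since the private method's exact steps are not documented here: the class ScaleWeightsSketch is hypothetical, trailing zero bits are moved from each value into its exponent, and the smallest exponent is then subtracted from all exponents so that the relative weights are unchanged.

```java
final class ScaleWeightsSketch {
    /** Rescales value * 2^exponent pairs in place so the minimum exponent is zero. */
    static void scale(long[] values, int[] exponents) {
        int min = Integer.MAX_VALUE;
        for (int i = 0; i < values.length; i++) {
            if (values[i] != 0) {
                // Move trailing zero bits of the value into the exponent
                int tz = Long.numberOfTrailingZeros(values[i]);
                values[i] >>>= tz;
                exponents[i] += tz;
                min = Math.min(min, exponents[i]);
            }
        }
        for (int i = 0; i < values.length; i++) {
            // Zero weights carry no exponent; others are shifted by the common minimum
            exponents[i] = values[i] == 0 ? 0 : exponents[i] - min;
        }
    }
}
```

For example, the weights 48 * 2^-6, 5 * 2^-2, 0 and 4 * 2^0 (that is 0.75, 1.25, 0 and 4) become 3 * 2^0, 5 * 2^0, 0 and 1 * 2^4, which preserves the ratio 3 : 5 : 0 : 16.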
-
sum
private static java.math.BigInteger sum(long[] values, int[] exponents, int[] indices)
Sum the integers at the specified indices. Integers are represented as value * 2^exponent.
- Parameters:
values - 53-bit values.
exponents - Power of 2 exponent.
indices - Indices to sum.
- Returns:
- the sum
-
toBigInteger
private static java.math.BigInteger toBigInteger(long value, int offset)
Convert the value and left shift offset to a BigInteger. It is assumed the value is at most 53 bits. This allows optimising the left shift if it is below 11 bits.
- Parameters:
value - 53-bit value.
offset - Left shift offset (must be positive).
- Returns:
- the BigInteger
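The optimisation rests on the fact that a 53-bit value shifted left by up to 10 bits still fits in the 63 value bits of a positive long. A minimal sketch follows; the class name is hypothetical and the threshold of 10 is inferred from the "below 11 bits" statement rather than taken from the source.

```java
import java.math.BigInteger;

final class ShiftedValueSketch {
    /** Assumed threshold: 53 + 10 = 63 bits, the largest positive long width. */
    static final int MAX_OFFSET = 10;

    /** Converts value * 2^offset to a BigInteger, assuming 0 <= value < 2^53 and offset >= 0. */
    static BigInteger toBigInteger(long value, int offset) {
        if (offset <= MAX_OFFSET) {
            // The shift cannot overflow so it is performed using primitive arithmetic
            return BigInteger.valueOf(value << offset);
        }
        return BigInteger.valueOf(value).shiftLeft(offset);
    }
}
```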
-
createSampler
private static FastLoadedDiceRollerDiscreteSampler createSampler(UniformRandomProvider rng, long[] frequencies, int[] indices, long m)
Creates the sampler.
It is assumed the frequencies are all positive and the sum does not overflow.
- Parameters:
rng - Generator of uniformly distributed random numbers.
frequencies - Observed frequencies of the discrete distribution.
indices - Indices of non-zero frequencies.
m - Sum of the frequencies.
- Returns:
- the sampler
-
createSampler
private static FastLoadedDiceRollerDiscreteSampler createSampler(UniformRandomProvider rng, long[] frequencies, int[] offsets, int[] indices, java.math.BigInteger m)
Creates the sampler. Frequencies are represented as a 53-bit value with a left-shift offset.
BigInteger.valueOf(value).shiftLeft(offset)
It is assumed the frequencies are all positive.
- Parameters:
rng - Generator of uniformly distributed random numbers.
frequencies - Observed frequencies of the discrete distribution.
offsets - Left shift offsets (must be positive).
indices - Indices of non-zero frequencies.
m - Sum of the frequencies.
- Returns:
- the sampler
-
testBit
private static boolean testBit(long value, int offset, int n)
Test the logical bit of the shifted integer representation. The value is assumed to have at most 53 bits of information. The offset is assumed to be positive. This is functionally equivalent to:
BigInteger.valueOf(value).shiftLeft(offset).testBit(n)
- Parameters:
value - 53-bit value.
offset - Left shift offset.
n - Index of bit to test.
- Returns:
- true if the bit is 1
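The equivalence with the BigInteger expression can be achieved without allocating a BigInteger: bit n of the shifted value is bit (n - offset) of the original value. The following is a sketch of that idea with a hypothetical class name, and is not necessarily how the private method is written.

```java
import java.math.BigInteger;

final class TestBitSketch {
    /** Tests bit n of (value << offset), assuming value >= 0 and offset >= 0. */
    static boolean testBit(long value, int offset, int n) {
        if (n < offset) {
            return false;  // bits shifted in from the right are zero
        }
        int i = n - offset;
        // Bits at or above index 64 do not exist in a long and are zero for a positive value
        return i < 64 && ((value >>> i) & 1L) != 0;
    }

    public static void main(String[] args) {
        long value = 5;  // binary 101
        BigInteger shifted = BigInteger.valueOf(value).shiftLeft(70);
        System.out.println(testBit(value, 70, 72) == shifted.testBit(72));  // prints true
    }
}
```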
-
checkWeightsNonZeroLength
private static int checkWeightsNonZeroLength(double[] weights)
Check the weights have a non-zero length.
- Parameters:
weights - Weights.
- Returns:
- the length
-
indicesOfNonZero
private static int[] indicesOfNonZero(long[] values)
Create the indices of non-zero values.
- Parameters:
values - Values.
- Returns:
- the indices
-
indexOfNonZero
static int indexOfNonZero(long[] frequencies)
Find the index of the first non-zero frequency.
- Parameters:
frequencies - Frequencies.
- Returns:
- the index
- Throws:
java.lang.IllegalStateException - if all frequencies are zero.
-
checkArraySize
static int checkArraySize(long size)
Check the size is valid for a 1D array.
- Parameters:
size - Size
- Returns:
- the size as an int
- Throws:
java.lang.IllegalArgumentException - if the size is too large for a 1D array.
-
-