mlpack 3.4.2
rbm.hpp
Go to the documentation of this file.
1
11#ifndef MLPACK_METHODS_ANN_RBM_RBM_HPP
12#define MLPACK_METHODS_ANN_RBM_RBM_HPP
13
14#include <mlpack/core.hpp>
16
17namespace mlpack {
18namespace ann {
19
33template<
34 typename InitializationRuleType,
35 typename DataType = arma::mat,
36 typename PolicyType = BinaryRBM
37>
38class RBM
39{
40 public:
42 typedef typename DataType::elem_type ElemType;
43
60 RBM(arma::Mat<ElemType> predictors,
61 InitializationRuleType initializeRule,
62 const size_t visibleSize,
63 const size_t hiddenSize,
64 const size_t batchSize = 1,
65 const size_t numSteps = 1,
66 const size_t negSteps = 1,
67 const size_t poolSize = 2,
68 const ElemType slabPenalty = 8,
69 const ElemType radius = 1,
70 const bool persistence = false);
71
72 // Reset the network.
73 template<typename Policy = PolicyType, typename InputType = DataType>
74 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
76
77 // Reset the network.
78 template<typename Policy = PolicyType, typename InputType = DataType>
79 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
81
97 template<typename OptimizerType, typename... CallbackType>
98 double Train(OptimizerType& optimizer, CallbackType&&... callbacks);
99
108 double Evaluate(const arma::Mat<ElemType>& parameters,
109 const size_t i,
110 const size_t batchSize);
111
119 template<typename Policy = PolicyType, typename InputType = DataType>
120 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, double>::type
121 FreeEnergy(const arma::Mat<ElemType>& input);
122
133 template<typename Policy = PolicyType, typename InputType = DataType>
134 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
135 double>::type
136 FreeEnergy(const arma::Mat<ElemType>& input);
137
144 template<typename Policy = PolicyType, typename InputType = DataType>
145 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
146 Phase(const InputType& input, DataType& gradient);
147
154 template<typename Policy = PolicyType, typename InputType = DataType>
155 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
156 Phase(const InputType& input, DataType& gradient);
157
165 template<typename Policy = PolicyType, typename InputType = DataType>
166 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
167 SampleHidden(const arma::Mat<ElemType>& input, arma::Mat<ElemType>& output);
168
179 template<typename Policy = PolicyType, typename InputType = DataType>
180 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
181 SampleHidden(const arma::Mat<ElemType>& input, arma::Mat<ElemType>& output);
182
190 template<typename Policy = PolicyType, typename InputType = DataType>
191 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
192 SampleVisible(arma::Mat<ElemType>& input, arma::Mat<ElemType>& output);
193
204 template<typename Policy = PolicyType, typename InputType = DataType>
205 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
206 SampleVisible(arma::Mat<ElemType>& input, arma::Mat<ElemType>& output);
207
214 template<typename Policy = PolicyType, typename InputType = DataType>
215 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
216 VisibleMean(InputType& input, DataType& output);
217
226 template<typename Policy = PolicyType, typename InputType = DataType>
227 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
228 VisibleMean(InputType& input, DataType& output);
229
236 template<typename Policy = PolicyType, typename InputType = DataType>
237 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
238 HiddenMean(const InputType& input, DataType& output);
239
250 template<typename Policy = PolicyType, typename InputType = DataType>
251 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
252 HiddenMean(const InputType& input, DataType& output);
253
262 template<typename Policy = PolicyType, typename InputType = DataType>
263 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
264 SpikeMean(const InputType& visible, DataType& spikeMean);
265
271 template<typename Policy = PolicyType, typename InputType = DataType>
272 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
273 SampleSpike(InputType& spikeMean, DataType& spike);
274
284 template<typename Policy = PolicyType, typename InputType = DataType>
285 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
286 SlabMean(const DataType& visible, DataType& spike, DataType& slabMean);
287
298 template<typename Policy = PolicyType, typename InputType = DataType>
299 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
300 SampleSlab(InputType& slabMean, DataType& slab);
301
309 void Gibbs(const arma::Mat<ElemType>& input,
310 arma::Mat<ElemType>& output,
311 const size_t steps = SIZE_MAX);
312
321 void Gradient(const arma::Mat<ElemType>& parameters,
322 const size_t i,
323 arma::Mat<ElemType>& gradient,
324 const size_t batchSize);
325
330 void Shuffle();
331
333 size_t NumFunctions() const { return numFunctions; }
334
336 size_t NumSteps() const { return numSteps; }
337
339 const arma::Mat<ElemType>& Parameters() const { return parameter; }
341 arma::Mat<ElemType>& Parameters() { return parameter; }
342
344 arma::Cube<ElemType> const& Weight() const { return weight; }
346 arma::Cube<ElemType>& Weight() { return weight; }
347
349 DataType const& VisibleBias() const { return visibleBias; }
351 DataType& VisibleBias() { return visibleBias; }
352
354 DataType const& HiddenBias() const { return hiddenBias; }
356 DataType& HiddenBias() { return hiddenBias; }
357
359 DataType const& SpikeBias() const { return spikeBias; }
361 DataType& SpikeBias() { return spikeBias; }
362
364 ElemType const& SlabPenalty() const { return 1.0 / slabPenalty; }
365
367 DataType const& VisiblePenalty() const { return visiblePenalty; }
369 DataType& VisiblePenalty() { return visiblePenalty; }
370
372 size_t const& VisibleSize() const { return visibleSize; }
374 size_t const& HiddenSize() const { return hiddenSize; }
376 size_t const& PoolSize() const { return poolSize; }
377
379 template<typename Archive>
380 void serialize(Archive& ar, const unsigned int /* version */);
381
382 private:
384 arma::Mat<ElemType> parameter;
386 arma::Mat<ElemType> predictors;
387 // Initializer for initializing the weights of the network.
388 InitializationRuleType initializeRule;
390 arma::Mat<ElemType> state;
392 size_t numFunctions;
394 size_t visibleSize;
396 size_t hiddenSize;
398 size_t batchSize;
400 size_t numSteps;
402 size_t negSteps;
404 size_t poolSize;
406 size_t steps;
408 arma::Cube<ElemType> weight;
410 DataType visibleBias;
412 DataType hiddenBias;
414 DataType preActivation;
416 DataType spikeBias;
418 DataType visiblePenalty;
420 DataType visibleMean;
422 DataType spikeMean;
424 DataType spikeSamples;
426 DataType slabMean;
428 ElemType slabPenalty;
430 ElemType radius;
432 arma::Mat<ElemType> hiddenReconstruction;
434 arma::Mat<ElemType> visibleReconstruction;
436 arma::Mat<ElemType> negativeSamples;
438 arma::Mat<ElemType> negativeGradient;
440 arma::Mat<ElemType> tempNegativeGradient;
442 arma::Mat<ElemType> positiveGradient;
444 arma::Mat<ElemType> gibbsTemporary;
446 bool persistence;
448 bool reset;
449};
450
451} // namespace ann
452} // namespace mlpack
453
454#include "rbm_impl.hpp"
455#include "spike_slab_rbm_impl.hpp"
456
457#endif
The implementation of the RBM module.
Definition: rbm.hpp:39
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, double >::type FreeEnergy(const arma::Mat< ElemType > &input)
This function calculates the free energy of the SpikeSlabRBM.
size_t const & PoolSize() const
Get the pool size.
Definition: rbm.hpp:376
DataType & VisibleBias()
Modify the visible bias of the network.
Definition: rbm.hpp:351
DataType & VisiblePenalty()
Modify the regularizer associated with visible variables.
Definition: rbm.hpp:369
arma::Mat< ElemType > & Parameters()
Modify the parameters of the network.
Definition: rbm.hpp:341
double Evaluate(const arma::Mat< ElemType > &parameters, const size_t i, const size_t batchSize)
Evaluate the RBM network with the given parameters.
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type VisibleMean(InputType &input, DataType &output)
The function calculates the mean for the visible layer.
DataType const & SpikeBias() const
Get the regularizer associated with spike variables.
Definition: rbm.hpp:359
size_t const & HiddenSize() const
Get the hidden size.
Definition: rbm.hpp:374
DataType const & VisibleBias() const
Return the visible bias of the network.
Definition: rbm.hpp:349
double Train(OptimizerType &optimizer, CallbackType &&... callbacks)
Train the RBM on the given input data.
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
Definition: rbm.hpp:333
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SampleHidden(const arma::Mat< ElemType > &input, arma::Mat< ElemType > &output)
This function samples the slab outputs from the Normal distribution with mean given by: $h_i \alpha^{-1} W_i^T v$ and variance: $\alpha^{-1}$.
void Shuffle()
Shuffle the order of function visitation.
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SpikeMean(const InputType &visible, DataType &spikeMean)
The function calculates the mean of the distribution P(h|v), where mean is given by: $\mathrm{sigm}(v^T W_i \alpha_i^{-1} W_i^T v + b_i)$.
arma::Cube< ElemType > & Weight()
Modify the weights of the network.
Definition: rbm.hpp:346
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SlabMean(const DataType &visible, DataType &spike, DataType &slabMean)
The function calculates the mean of Normal distribution of P(s|v, h), where the mean is given by: $h_i \alpha^{-1} W_i^T v$.
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type HiddenMean(const InputType &input, DataType &output)
The function calculates the mean for the hidden layer.
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleVisible(arma::Mat< ElemType > &input, arma::Mat< ElemType > &output)
This function samples the visible layer given the hidden layer using Bernoulli function.
std::enable_if< std::is_same< Policy, BinaryRBM >::value, double >::type FreeEnergy(const arma::Mat< ElemType > &input)
This function calculates the free energy of the BinaryRBM.
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type Phase(const InputType &input, DataType &gradient)
Calculates the gradient of the RBM network on the provided input.
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type Reset()
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type HiddenMean(const InputType &input, DataType &output)
The function calculates the mean of the Normal distribution of P(s|v, h).
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type VisibleMean(InputType &input, DataType &output)
The function calculates the mean of the Normal distribution of P(v|s, h).
void Gibbs(const arma::Mat< ElemType > &input, arma::Mat< ElemType > &output, const size_t steps=SIZE_MAX)
This function does the k-step Gibbs Sampling.
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SampleSlab(InputType &slabMean, DataType &slab)
The function samples from the Normal distribution of P(s|v, h), where the mean is given by: $h_i \alpha^{-1} W_i^T v$ and variance is given by: $\alpha^{-1}$.
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SampleVisible(arma::Mat< ElemType > &input, arma::Mat< ElemType > &output)
This function samples the visible outputs from the Normal distribution of P(v|s, h), with mean given by: $\Lambda^{-1} \sum_{i=1}^{N} W_i s_i h_i$ and variance: $\Lambda^{-1}$.
size_t const & VisibleSize() const
Get the visible size.
Definition: rbm.hpp:372
arma::Cube< ElemType > const & Weight() const
Get the weights of the network.
Definition: rbm.hpp:344
void Gradient(const arma::Mat< ElemType > &parameters, const size_t i, arma::Mat< ElemType > &gradient, const size_t batchSize)
Calculates the gradients for the RBM network.
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SampleSpike(InputType &spikeMean, DataType &spike)
The function samples the spike function using Bernoulli distribution.
DataType & SpikeBias()
Modify the regularizer associated with spike variables.
Definition: rbm.hpp:361
DataType & HiddenBias()
Modify the hidden bias of the network.
Definition: rbm.hpp:356
const arma::Mat< ElemType > & Parameters() const
Return the parameters of the network.
Definition: rbm.hpp:339
size_t NumSteps() const
Return the number of steps of Gibbs Sampling.
Definition: rbm.hpp:336
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type Phase(const InputType &input, DataType &gradient)
Calculates the gradient of the RBM network on the provided input.
DataType const & HiddenBias() const
Return the hidden bias of the network.
Definition: rbm.hpp:354
ElemType const & SlabPenalty() const
Get the regularizer associated with slab variables.
Definition: rbm.hpp:364
DataType::elem_type ElemType
Definition: rbm.hpp:42
DataType const & VisiblePenalty() const
Get the regularizer associated with visible variables.
Definition: rbm.hpp:367
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleHidden(const arma::Mat< ElemType > &input, arma::Mat< ElemType > &output)
This function samples the hidden layer given the visible layer using Bernoulli function.
void serialize(Archive &ar, const unsigned int)
Serialize the model.
RBM(arma::Mat< ElemType > predictors, InitializationRuleType initializeRule, const size_t visibleSize, const size_t hiddenSize, const size_t batchSize=1, const size_t numSteps=1, const size_t negSteps=1, const size_t poolSize=2, const ElemType slabPenalty=8, const ElemType radius=1, const bool persistence=false)
Initialize all the parameters of the network using initializeRule.
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type Reset()
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1