11#ifndef MLPACK_METHODS_ANN_GAN_GAN_HPP
12#define MLPACK_METHODS_ANN_GAN_GAN_HPP
59 typename InitializationRuleType,
61 typename PolicyType = StandardGAN
85 InitializationRuleType& initializeRule,
87 const size_t noiseDim,
88 const size_t batchSize,
89 const size_t generatorUpdateStep,
90 const size_t preTrainSize,
91 const double multiplier,
92 const double clippingParameter = 0.01,
93 const double lambda = 10.0);
123 template<
typename OptimizerType,
typename... CallbackTypes>
125 OptimizerType& Optimizer,
126 CallbackTypes&&... callbacks);
137 template<
typename Policy = PolicyType>
138 typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
139 std::is_same<Policy, DCGAN>::value,
double>::type
142 const size_t batchSize);
152 template<
typename Policy = PolicyType>
153 typename std::enable_if<std::is_same<Policy, WGAN>::value,
157 const size_t batchSize);
167 template<
typename Policy = PolicyType>
168 typename std::enable_if<std::is_same<Policy, WGANGP>::value,
172 const size_t batchSize);
184 template<
typename GradType,
typename Policy = PolicyType>
185 typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
186 std::is_same<Policy, DCGAN>::value,
double>::type
190 const size_t batchSize);
202 template<
typename GradType,
typename Policy = PolicyType>
203 typename std::enable_if<std::is_same<Policy, WGAN>::value,
208 const size_t batchSize);
220 template<
typename GradType,
typename Policy = PolicyType>
221 typename std::enable_if<std::is_same<Policy, WGANGP>::value,
226 const size_t batchSize);
238 template<
typename Policy = PolicyType>
239 typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
240 std::is_same<Policy, DCGAN>::value,
void>::type
244 const size_t batchSize);
256 template<
typename Policy = PolicyType>
257 typename std::enable_if<std::is_same<Policy, WGAN>::value,
void>::type
261 const size_t batchSize);
273 template<
typename Policy = PolicyType>
274 typename std::enable_if<std::is_same<Policy, WGANGP>::value,
279 const size_t batchSize);
300 void Predict(arma::mat input, arma::mat& output);
320 const arma::mat&
Responses()
const {
return responses; }
330 template<
typename Archive>
338 void ResetDeterministic();
341 arma::mat predictors;
349 InitializationRuleType initializeRule;
361 size_t generatorUpdateStep;
367 double clippingParameter;
377 arma::mat currentInput;
379 arma::mat currentTarget;
389 arma::mat gradientDiscriminator;
391 arma::mat noiseGradientDiscriminator;
393 arma::mat normGradientDiscriminator;
397 arma::mat gradientGenerator;
410#include "gan_impl.hpp"
411#include "wgan_impl.hpp"
412#include "wgangp_impl.hpp"
DeltaVisitor exposes the delta parameter of the given module.
The implementation of the standard GAN module.
GAN(GAN &&)
Move constructor.
GAN(const GAN &)
Copy constructor.
void Forward(const arma::mat &input)
This function does a forward pass through the GAN network.
arma::mat & Parameters()
Modify the parameters of the network.
std::enable_if< std::is_same< Policy, WGAN >::value, double >::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the WGAN.
Model & Generator()
Modify the generator of the GAN.
std::enable_if< std::is_same< Policy, WGANGP >::value, void >::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for WGAN-GP.
void Predict(arma::mat input, arma::mat &output)
This function predicts the output of the network on the given input.
const arma::mat & Parameters() const
Return the parameters of the network.
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
void Shuffle()
Shuffle the order of function visitation.
std::enable_if< std::is_same< Policy, WGAN >::value, double >::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the WGAN.
std::enable_if< std::is_same< Policy, WGAN >::value, void >::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for WGAN.
const Model & Discriminator() const
Return the discriminator of the GAN.
double Train(arma::mat trainData, OptimizerType &Optimizer, CallbackTypes &&... callbacks)
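Train function: trains the GAN on the given training data using the given optimizer and optional callbacks, and returns a double.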
const Model & Generator() const
Return the generator of the GAN.
std::enable_if< std::is_same< Policy, WGANGP >::value, double >::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the WGAN-GP.
void ResetData(arma::mat trainData)
Initialize the generator, discriminator and weights of the model for training.
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, void >::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for Standard GAN and DCGAN.
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the Standard GAN and DCGAN.
const arma::mat & Predictors() const
Get the matrix of data points (predictors).
arma::mat & Predictors()
Modify the matrix of data points (predictors).
Model & Discriminator()
Modify the discriminator of the GAN.
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the Standard GAN and DCGAN.
const arma::mat & Responses() const
Get the matrix of responses to the input data points.
arma::mat & Responses()
Modify the matrix of responses to the input data points.
void serialize(Archive &ar, const unsigned int)
Serialize the model.
std::enable_if< std::is_same< Policy, WGANGP >::value, double >::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the WGAN-GP.
GAN(Model generator, Model discriminator, InitializationRuleType &initializeRule, Noise &noiseFunction, const size_t noiseDim, const size_t batchSize, const size_t generatorUpdateStep, const size_t preTrainSize, const double multiplier, const double clippingParameter=0.01, const double lambda=10.0)
Constructor for GAN class.
OutputParameterVisitor exposes the output parameter of the given module.
ResetVisitor executes the Reset() function.
WeightSizeVisitor returns the number of weights of the given module.
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
Linear algebra utility functions, generally performed on matrices or vectors.