mlpack 3.4.2
gan.hpp
Go to the documentation of this file.
1
11#ifndef MLPACK_METHODS_ANN_GAN_GAN_HPP
12#define MLPACK_METHODS_ANN_GAN_GAN_HPP
13
14#include <mlpack/core.hpp>
15
23
24
25namespace mlpack {
26namespace ann {
27
57template<
58 typename Model,
59 typename InitializationRuleType,
60 typename Noise,
61 typename PolicyType = StandardGAN
62>
63class GAN
64{
65 public:
83 GAN(Model generator,
84 Model discriminator,
85 InitializationRuleType& initializeRule,
86 Noise& noiseFunction,
87 const size_t noiseDim,
88 const size_t batchSize,
89 const size_t generatorUpdateStep,
90 const size_t preTrainSize,
91 const double multiplier,
92 const double clippingParameter = 0.01,
93 const double lambda = 10.0);
94
96 GAN(const GAN&);
97
99 GAN(GAN&&);
100
107 void ResetData(arma::mat trainData);
108
109 // Reset function.
110 void Reset();
111
123 template<typename OptimizerType, typename... CallbackTypes>
124 double Train(arma::mat trainData,
125 OptimizerType& Optimizer,
126 CallbackTypes&&... callbacks);
127
137 template<typename Policy = PolicyType>
138 typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
139 std::is_same<Policy, DCGAN>::value, double>::type
140 Evaluate(const arma::mat& parameters,
141 const size_t i,
142 const size_t batchSize);
143
152 template<typename Policy = PolicyType>
153 typename std::enable_if<std::is_same<Policy, WGAN>::value,
154 double>::type
155 Evaluate(const arma::mat& parameters,
156 const size_t i,
157 const size_t batchSize);
158
167 template<typename Policy = PolicyType>
168 typename std::enable_if<std::is_same<Policy, WGANGP>::value,
169 double>::type
170 Evaluate(const arma::mat& parameters,
171 const size_t i,
172 const size_t batchSize);
173
184 template<typename GradType, typename Policy = PolicyType>
185 typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
186 std::is_same<Policy, DCGAN>::value, double>::type
187 EvaluateWithGradient(const arma::mat& parameters,
188 const size_t i,
189 GradType& gradient,
190 const size_t batchSize);
191
202 template<typename GradType, typename Policy = PolicyType>
203 typename std::enable_if<std::is_same<Policy, WGAN>::value,
204 double>::type
205 EvaluateWithGradient(const arma::mat& parameters,
206 const size_t i,
207 GradType& gradient,
208 const size_t batchSize);
209
220 template<typename GradType, typename Policy = PolicyType>
221 typename std::enable_if<std::is_same<Policy, WGANGP>::value,
222 double>::type
223 EvaluateWithGradient(const arma::mat& parameters,
224 const size_t i,
225 GradType& gradient,
226 const size_t batchSize);
227
238 template<typename Policy = PolicyType>
239 typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
240 std::is_same<Policy, DCGAN>::value, void>::type
241 Gradient(const arma::mat& parameters,
242 const size_t i,
243 arma::mat& gradient,
244 const size_t batchSize);
245
256 template<typename Policy = PolicyType>
257 typename std::enable_if<std::is_same<Policy, WGAN>::value, void>::type
258 Gradient(const arma::mat& parameters,
259 const size_t i,
260 arma::mat& gradient,
261 const size_t batchSize);
262
273 template<typename Policy = PolicyType>
274 typename std::enable_if<std::is_same<Policy, WGANGP>::value,
275 void>::type
276 Gradient(const arma::mat& parameters,
277 const size_t i,
278 arma::mat& gradient,
279 const size_t batchSize);
280
285 void Shuffle();
286
292 void Forward(const arma::mat& input);
293
300 void Predict(arma::mat input, arma::mat& output);
301
303 const arma::mat& Parameters() const { return parameter; }
305 arma::mat& Parameters() { return parameter; }
306
308 const Model& Generator() const { return generator; }
310 Model& Generator() { return generator; }
312 const Model& Discriminator() const { return discriminator; }
314 Model& Discriminator() { return discriminator; }
315
317 size_t NumFunctions() const { return numFunctions; }
318
320 const arma::mat& Responses() const { return responses; }
322 arma::mat& Responses() { return responses; }
323
325 const arma::mat& Predictors() const { return predictors; }
327 arma::mat& Predictors() { return predictors; }
328
330 template<typename Archive>
331 void serialize(Archive& ar, const unsigned int /* version */);
332
333 private:
338 void ResetDeterministic();
339
341 arma::mat predictors;
343 arma::mat parameter;
345 Model generator;
347 Model discriminator;
349 InitializationRuleType initializeRule;
351 Noise noiseFunction;
353 size_t noiseDim;
355 size_t numFunctions;
357 size_t batchSize;
359 size_t currentBatch;
361 size_t generatorUpdateStep;
363 size_t preTrainSize;
365 double multiplier;
367 double clippingParameter;
369 double lambda;
371 bool reset;
373 DeltaVisitor deltaVisitor;
375 arma::mat responses;
377 arma::mat currentInput;
379 arma::mat currentTarget;
381 OutputParameterVisitor outputParameterVisitor;
383 WeightSizeVisitor weightSizeVisitor;
385 ResetVisitor resetVisitor;
387 arma::mat gradient;
389 arma::mat gradientDiscriminator;
391 arma::mat noiseGradientDiscriminator;
393 arma::mat normGradientDiscriminator;
395 arma::mat noise;
397 arma::mat gradientGenerator;
399 bool deterministic;
401 size_t genWeights;
403 size_t discWeights;
404};
405
406} // namespace ann
407} // namespace mlpack
408
409// Include implementation.
410#include "gan_impl.hpp"
411#include "wgan_impl.hpp"
412#include "wgangp_impl.hpp"
413
414
415#endif
DeltaVisitor exposes the delta parameter of the given module.
The implementation of the standard GAN module.
Definition: gan.hpp:64
GAN(GAN &&)
Move constructor.
GAN(const GAN &)
Copy constructor.
void Forward(const arma::mat &input)
This function does a forward pass through the GAN network.
arma::mat & Parameters()
Modify the parameters of the network.
Definition: gan.hpp:305
std::enable_if< std::is_same< Policy, WGAN >::value, double >::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the WGAN.
Model & Generator()
Modify the generator of the GAN.
Definition: gan.hpp:310
std::enable_if< std::is_same< Policy, WGANGP >::value, void >::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for WGAN-GP.
void Predict(arma::mat input, arma::mat &output)
This function predicts the output of the network on the given input.
const arma::mat & Parameters() const
Return the parameters of the network.
Definition: gan.hpp:303
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
Definition: gan.hpp:317
void Shuffle()
Shuffle the order of function visitation.
std::enable_if< std::is_same< Policy, WGAN >::value, double >::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the WGAN.
std::enable_if< std::is_same< Policy, WGAN >::value, void >::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for WGAN.
const Model & Discriminator() const
Return the discriminator of the GAN.
Definition: gan.hpp:312
double Train(arma::mat trainData, OptimizerType &Optimizer, CallbackTypes &&... callbacks)
Train function.
const Model & Generator() const
Return the generator of the GAN.
Definition: gan.hpp:308
std::enable_if< std::is_same< Policy, WGANGP >::value, double >::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the WGAN-GP.
void ResetData(arma::mat trainData)
Initialize the generator, discriminator and weights of the model for training.
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, void >::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for Standard GAN and DCGAN.
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the Standard GAN and DCGAN.
const arma::mat & Predictors() const
Get the matrix of data points (predictors).
Definition: gan.hpp:325
arma::mat & Predictors()
Modify the matrix of data points (predictors).
Definition: gan.hpp:327
Model & Discriminator()
Modify the discriminator of the GAN.
Definition: gan.hpp:314
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the Standard GAN and DCGAN.
const arma::mat & Responses() const
Get the matrix of responses to the input data points.
Definition: gan.hpp:320
arma::mat & Responses()
Modify the matrix of responses to the input data points.
Definition: gan.hpp:322
void serialize(Archive &ar, const unsigned int)
Serialize the model.
std::enable_if< std::is_same< Policy, WGANGP >::value, double >::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the WGAN-GP.
GAN(Model generator, Model discriminator, InitializationRuleType &initializeRule, Noise &noiseFunction, const size_t noiseDim, const size_t batchSize, const size_t generatorUpdateStep, const size_t preTrainSize, const double multiplier, const double clippingParameter=0.01, const double lambda=10.0)
Constructor for GAN class.
OutputParameterVisitor exposes the output parameter of the given module.
ResetVisitor executes the Reset() function.
WeightSizeVisitor returns the number of weights of the given module.
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1