Package org.ojalgo.ann
Class NetworkTrainer
java.lang.Object
    org.ojalgo.ann.WrappedANN
        org.ojalgo.ann.NetworkTrainer
- All Implemented Interfaces:
java.util.function.Supplier<ArtificialNeuralNetwork>
public final class NetworkTrainer extends WrappedANN
An Artificial Neural Network (ANN) builder/trainer.
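All of NetworkTrainer's configuration methods (rate, lasso, ridge, error, dropouts, bias, weight) return NetworkTrainer, which points to a fluent style where calls are chained. The sketch below illustrates only that chaining pattern. It uses a hypothetical, self-contained `TrainerSettings` class that is not ojalgo code; its fields and default values are illustrative assumptions.

```java
/**
 * Hypothetical illustration of the fluent (method-chaining) configuration
 * style used by NetworkTrainer. Not part of ojalgo.
 */
final class TrainerSettings {

    // Illustrative defaults only, not ojalgo's actual defaults.
    private double rate = 0.1;
    private double lasso = 0.0; // L1 factor, 0 means "off"
    private double ridge = 0.0; // L2 factor, 0 means "off"
    private boolean dropouts = false;

    // Each setter mutates this instance and returns it, so calls can be chained.
    TrainerSettings rate(double rate) {
        this.rate = rate;
        return this;
    }

    TrainerSettings lasso(double factor) {
        this.lasso = factor;
        return this;
    }

    TrainerSettings ridge(double factor) {
        this.ridge = factor;
        return this;
    }

    TrainerSettings dropouts() {
        this.dropouts = true;
        return this;
    }

    double getRate() {
        return rate;
    }

    double getLasso() {
        return lasso;
    }

    double getRidge() {
        return ridge;
    }

    boolean isDropouts() {
        return dropouts;
    }
}
```

Usage: `new TrainerSettings().rate(0.05).ridge(0.001).dropouts()` configures everything in one expression.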
Field Summary
Modifier and Type                           Field
private TrainingConfiguration               myConfiguration
private PhysicalStore<java.lang.Double>[]   myGradients
-
Constructor Summary
Constructor
NetworkTrainer(ArtificialNeuralNetwork network, int batchSize)
-
Method Summary
Modifier and Type          Method and Description
NetworkTrainer             activator(int layer, ArtificialNeuralNetwork.Activator activator)
                           Deprecated. Use NetworkBuilder and NetworkBuilder.layer(int, Activator) instead.
NetworkTrainer             activators(ArtificialNeuralNetwork.Activator activator)
                           Deprecated. Use NetworkBuilder and NetworkBuilder.layer(int, Activator) instead.
NetworkTrainer             activators(ArtificialNeuralNetwork.Activator... activators)
                           Deprecated. Use NetworkBuilder and NetworkBuilder.layer(int, Activator) instead.
NetworkTrainer             bias(int layer, int output, double bias)
NetworkTrainer             dropouts()
boolean                    equals(java.lang.Object obj)
NetworkTrainer             error(ArtificialNeuralNetwork.Error error)
(package private) double   error(Access1D<?> target, Access1D<?> current)
int                        hashCode()
NetworkTrainer             lasso(double factor)
                           L1 lasso regularisation
DataBatch                  newOutputBatch()
NetworkTrainer             rate(double rate)
NetworkTrainer             ridge(double factor)
                           L2 ridge regularisation
Structure2D[]              structure()
java.lang.String           toString()
void                       train(java.lang.Iterable<? extends Access1D<java.lang.Double>> givenInputs, java.lang.Iterable<? extends Access1D<java.lang.Double>> targetOutputs)
                           Deprecated. Just use train(Access1D, Access1D) instead.
void                       train(Access1D<java.lang.Double> givenInput, Access1D<java.lang.Double> targetOutput)
                           The arguments are typed as Access1D, but it's probably best to think of (and create) them as something 2D, where the number of rows matches the batch size and the number of columns matches the number of inputs and outputs, respectively.
NetworkTrainer             weight(int layer, int input, int output, double weight)
-
Methods inherited from class org.ojalgo.ann.WrappedANN
adjust, depth, get, getActivator, getBatchSize, getBias, getInput, getInput, getOutput, getOutput, getOutputActivator, getWeight, getWeights, invoke, newInputBatch, randomise, setActivator, setBias, setWeight
Field Detail
-
myConfiguration
private final TrainingConfiguration myConfiguration
-
myGradients
private final PhysicalStore<java.lang.Double>[] myGradients
Constructor Detail
-
NetworkTrainer
NetworkTrainer(ArtificialNeuralNetwork network, int batchSize)
Method Detail
-
activator
@Deprecated public NetworkTrainer activator(int layer, ArtificialNeuralNetwork.Activator activator)
Deprecated. Use NetworkBuilder and NetworkBuilder.layer(int, Activator) instead.
- Parameters:
layer - 0-based index among the calculation layers (excluding the input layer)
activator - The activator function to use
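The `layer` parameter counts only calculation layers, so index 0 is the first layer after the input. A minimal sketch of how one activator per calculation layer could be stored and applied, assuming a hypothetical `Act` enum (not ojalgo's `ArtificialNeuralNetwork.Activator`, whose constants may differ) and a hypothetical `LayerActivators` holder:

```java
import java.util.Arrays;
import java.util.function.DoubleUnaryOperator;

/** Hypothetical activator functions; not ojalgo's Activator enum. */
enum Act {
    IDENTITY(x -> x),
    SIGMOID(x -> 1.0 / (1.0 + Math.exp(-x))),
    RELU(x -> Math.max(0.0, x));

    private final DoubleUnaryOperator function;

    Act(DoubleUnaryOperator function) {
        this.function = function;
    }

    double apply(double x) {
        return function.applyAsDouble(x);
    }
}

/** One activator per calculation layer; index 0 is the first layer after the input. */
final class LayerActivators {

    private final Act[] perLayer;

    LayerActivators(int numberOfCalculationLayers) {
        perLayer = new Act[numberOfCalculationLayers];
        Arrays.fill(perLayer, Act.SIGMOID); // illustrative default
    }

    LayerActivators activator(int layer, Act act) {
        perLayer[layer] = act; // 0-based, input layer not counted
        return this;
    }

    /** Applies the given calculation layer's activator to each value. */
    double[] activate(int layer, double[] values) {
        double[] out = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            out[i] = perLayer[layer].apply(values[i]);
        }
        return out;
    }
}
```

With two calculation layers, `new LayerActivators(2).activator(1, Act.RELU)` keeps the hidden layer (index 0) on the default and switches the output layer (index 1) to ReLU.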
-
activators
@Deprecated public NetworkTrainer activators(ArtificialNeuralNetwork.Activator activator)
Deprecated. Use NetworkBuilder and NetworkBuilder.layer(int, Activator) instead.
-
activators
@Deprecated public NetworkTrainer activators(ArtificialNeuralNetwork.Activator... activators)
Deprecated. Use NetworkBuilder and NetworkBuilder.layer(int, Activator) instead.
-
bias
public NetworkTrainer bias(int layer, int output, double bias)
-
dropouts
public NetworkTrainer dropouts()
-
equals
public boolean equals(java.lang.Object obj)
- Overrides:
equals in class WrappedANN
-
error
public NetworkTrainer error(ArtificialNeuralNetwork.Error error)
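The trainer takes an `ArtificialNeuralNetwork.Error` that measures how far the network's output is from the target; the package-private `error(Access1D, Access1D)` returns that measure as a double. A sketch of two common choices, half squared difference and cross entropy, written against plain arrays. Which error functions ojalgo actually offers, and their exact formulas, are assumptions here, and `ErrorMeasures` is hypothetical.

```java
/** Hypothetical error measures over plain arrays; not ojalgo code. */
final class ErrorMeasures {

    private ErrorMeasures() {
    }

    /** 0.5 * sum of squared differences between target and current output. */
    static double halfSquaredDifference(double[] target, double[] current) {
        double sum = 0.0;
        for (int i = 0; i < target.length; i++) {
            double diff = target[i] - current[i];
            sum += diff * diff;
        }
        return 0.5 * sum;
    }

    /** -sum(target * ln(current)); assumes current values are probabilities in (0, 1]. */
    static double crossEntropy(double[] target, double[] current) {
        double sum = 0.0;
        for (int i = 0; i < target.length; i++) {
            sum -= target[i] * Math.log(current[i]);
        }
        return sum;
    }
}
```

For target {1, 0} and output {0.5, 0.5}, the half squared difference is 0.25 and the cross entropy is ln 2.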
-
hashCode
public int hashCode()
- Overrides:
hashCode in class WrappedANN
-
lasso
public NetworkTrainer lasso(double factor)
L1 lasso regularisation
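L1 (lasso) regularisation adds a penalty proportional to the sum of absolute weights, which tends to push small weights to exactly zero. A minimal sketch of the penalty and its (sub)gradient, assuming `factor` scales the penalty directly; how ojalgo applies it internally is not stated here, and `Lasso` is a hypothetical helper.

```java
/** Hypothetical L1 (lasso) helper; not ojalgo code. */
final class Lasso {

    private Lasso() {
    }

    /** penalty = factor * sum(|w|) */
    static double penalty(double factor, double[] weights) {
        double sum = 0.0;
        for (double w : weights) {
            sum += Math.abs(w);
        }
        return factor * sum;
    }

    /** d(penalty)/dw = factor * sign(w), using 0 as the subgradient at w == 0. */
    static double gradient(double factor, double weight) {
        return factor * Math.signum(weight);
    }
}
```

The gradient term has the same magnitude for every non-zero weight, which is why L1 can drive small weights all the way to zero.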
-
newOutputBatch
public DataBatch newOutputBatch()
- Overrides:
newOutputBatch in class WrappedANN
- See Also:
WrappedANN.newInputBatch()
-
rate
public NetworkTrainer rate(double rate)
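`rate` has no description here; in a gradient-based trainer it is presumably the learning rate, the step size applied to each gradient. A minimal sketch of one plain gradient-descent update under that assumption, using a hypothetical `GradientStep` helper:

```java
/** Hypothetical gradient-descent step; assumes `rate` is the learning rate. */
final class GradientStep {

    private GradientStep() {
    }

    /** Updates weights in place: w <- w - rate * dE/dw. */
    static void apply(double rate, double[] weights, double[] gradients) {
        for (int i = 0; i < weights.length; i++) {
            weights[i] -= rate * gradients[i];
        }
    }
}
```

A larger rate takes bigger steps per update; too large a value can overshoot the minimum instead of converging.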
-
ridge
public NetworkTrainer ridge(double factor)
L2 ridge regularisation
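L2 (ridge) regularisation penalises the sum of squared weights, shrinking all weights towards zero, usually without making any exactly zero. A minimal sketch, assuming the convention penalty = factor/2 · Σw², so that the gradient is simply factor · w. Conventions differ (some omit the 1/2), ojalgo's is not stated here, and `Ridge` is a hypothetical helper.

```java
/** Hypothetical L2 (ridge) helper; not ojalgo code. */
final class Ridge {

    private Ridge() {
    }

    /** penalty = factor / 2 * sum(w^2) */
    static double penalty(double factor, double[] weights) {
        double sum = 0.0;
        for (double w : weights) {
            sum += w * w;
        }
        return 0.5 * factor * sum;
    }

    /** d(penalty)/dw = factor * w */
    static double gradient(double factor, double weight) {
        return factor * weight;
    }
}
```

Unlike the L1 term, this gradient is proportional to the weight itself, so large weights are pulled back harder than small ones.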
-
structure
public Structure2D[] structure()
- Overrides:
structure in class WrappedANN
-
toString
public java.lang.String toString()
- Overrides:
toString in class java.lang.Object
-
train
public void train(Access1D<java.lang.Double> givenInput, Access1D<java.lang.Double> targetOutput)
The arguments are typed as Access1D, but it's probably best to think of (and create) them as something 2D, where the number of rows matches the batch size and the number of columns matches the number of inputs and outputs, respectively. When the batch size is 1, the arguments can actually be 1D.
- Parameters:
givenInput - One or more input examples, depending on the batch size
targetOutput - One or more matching output targets
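The shape rule for `train` (rows = batch size, columns = number of inputs or outputs) can be checked before training. A minimal sketch using plain `double[][]` arrays as stand-ins for ojalgo's 2D structures; `BatchShapes` and its methods are hypothetical. It also shows that, with a batch size of 1, a single 1D example is just a one-row batch.

```java
/** Hypothetical shape checks for training batches; not ojalgo code. */
final class BatchShapes {

    private BatchShapes() {
    }

    /** Throws if the batch does not have batchSize rows of exactly width columns. */
    static void check(double[][] batch, int batchSize, int width) {
        if (batch.length != batchSize) {
            throw new IllegalArgumentException(
                "Expected " + batchSize + " rows but got " + batch.length);
        }
        for (double[] row : batch) {
            if (row.length != width) {
                throw new IllegalArgumentException(
                    "Expected " + width + " columns but got " + row.length);
            }
        }
    }

    /** With a batch size of 1, a 1D example is equivalent to a 1-row batch. */
    static double[][] asSingleRowBatch(double[] example) {
        return new double[][] { example.clone() };
    }
}
```

For a trainer created with batch size 2 on a network with 3 inputs, the input batch would be checked with `check(inputs, 2, 3)` and the target batch against the number of outputs.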
-
train
@Deprecated public void train(java.lang.Iterable<? extends Access1D<java.lang.Double>> givenInputs, java.lang.Iterable<? extends Access1D<java.lang.Double>> targetOutputs)
Deprecated. Just use train(Access1D, Access1D) instead.
Note that the required Iterables can be obtained by calling Access2D.rows() or Access2D.columns() on anything "2D".
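The deprecated variant takes one `Iterable` of examples per argument, whereas `train(Access1D, Access1D)` takes a whole batch at once. A minimal sketch of converting between the two views with plain arrays: the hypothetical `rows(...)` plays the role that `Access2D.rows()` plays for ojalgo structures, and the hypothetical `stack(...)` goes the other way.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical conversions between a 2D batch and an Iterable of rows; not ojalgo code. */
final class RowViews {

    private RowViews() {
    }

    /** Exposes each row of a 2D array as one element of an Iterable. */
    static Iterable<double[]> rows(double[][] batch) {
        return Arrays.asList(batch);
    }

    /** Collects an Iterable of equally long rows into a single 2D batch. */
    static double[][] stack(Iterable<double[]> rows) {
        List<double[]> collected = new ArrayList<>();
        for (double[] row : rows) {
            collected.add(row.clone());
        }
        return collected.toArray(new double[0][]);
    }
}
```

Stacking the rows back together gives a batch that can be passed in one call instead of example by example.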
-
weight
public NetworkTrainer weight(int layer, int input, int output, double weight)
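`weight(layer, input, output, weight)` and `bias(layer, output, bias)` address individual parameters. A minimal sketch of one plausible storage layout, assuming (as documented for the activator methods) that `layer` is a 0-based index among calculation layers, so layer `l` has as many inputs as layer `l - 1` has outputs. `LayerParameters` is hypothetical and does not describe ojalgo's internal representation.

```java
/** Hypothetical per-layer weight/bias storage; not ojalgo's internals. */
final class LayerParameters {

    // weights[layer][input][output] and biases[layer][output]
    private final double[][][] weights;
    private final double[][] biases;

    /**
     * @param numberOfInputs network inputs (the input layer has no parameters)
     * @param layerOutputs   number of outputs of each calculation layer
     */
    LayerParameters(int numberOfInputs, int... layerOutputs) {
        weights = new double[layerOutputs.length][][];
        biases = new double[layerOutputs.length][];
        int inputs = numberOfInputs;
        for (int layer = 0; layer < layerOutputs.length; layer++) {
            weights[layer] = new double[inputs][layerOutputs[layer]];
            biases[layer] = new double[layerOutputs[layer]];
            inputs = layerOutputs[layer]; // the next layer consumes this layer's outputs
        }
    }

    LayerParameters weight(int layer, int input, int output, double weight) {
        weights[layer][input][output] = weight;
        return this;
    }

    LayerParameters bias(int layer, int output, double bias) {
        biases[layer][output] = bias;
        return this;
    }

    double getWeight(int layer, int input, int output) {
        return weights[layer][input][output];
    }

    double getBias(int layer, int output) {
        return biases[layer][output];
    }

    /** Number of inputs feeding the given calculation layer. */
    int inputsOf(int layer) {
        return weights[layer].length;
    }
}
```

For a 4-input network with calculation layers of 3 and 2 nodes, layer 0 holds a 4 × 3 weight grid and layer 1 a 3 × 2 grid.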