public class BackpropagationTrainer extends Object implements Trainer, Serializable
Modifier and Type | Field and Description |
---|---|
static String | PROP_BATCH_MODE - Name of the batchMode property |
static String | PROP_BATCH_SIZE - Name of the batchSize property |
static String | PROP_LEARNING_RATE - Name of the learningRate property |
static String | PROP_MAX_EPOCHS - Name of the maxEpochs property |
static String | PROP_MAX_ERROR - Name of the maxError property |
static String | PROP_MOMENTUM - Name of the momentum property |
static String | PROP_OPTIMIZER_TYPE - Name of the optimizer property |
Constructor and Description |
---|
BackpropagationTrainer(NeuralNetwork neuralNet) - Creates an instance of BackpropagationTrainer for the given neural network to train. |
BackpropagationTrainer(Properties prop) - Creates an instance of BackpropagationTrainer with the given properties. |
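The Properties-based constructor and setProperties(Properties) configure the trainer from key/value pairs. The actual key strings behind the PROP_* constants are not listed on this page, so the standalone sketch below uses hypothetical keys and a hypothetical TrainerSettings class to show the general pattern: read only the keys that are present and parse each value to the type of the corresponding setting. It is not Deep Netts code; only the 0.01 learning-rate default comes from this reference.

```java
import java.util.Properties;

/** Hypothetical sketch of Properties-driven configuration; not the Deep Netts implementation. */
final class TrainerSettings {
    // Hypothetical key strings: the real values of the PROP_* constants are not documented here.
    static final String KEY_LEARNING_RATE = "learningRate";
    static final String KEY_MOMENTUM = "momentum";
    static final String KEY_MAX_EPOCHS = "maxEpochs";
    static final String KEY_BATCH_MODE = "batchMode";
    static final String KEY_BATCH_SIZE = "batchSize";

    float learningRate = 0.01f; // default initial value stated in this reference
    float momentum = 0f;        // illustrative default for this sketch
    long maxEpochs = 1000L;     // illustrative default for this sketch
    boolean batchMode = false;  // illustrative default for this sketch
    int batchSize = 1;          // illustrative default for this sketch

    /** Overrides only the settings whose keys are present in prop. */
    static TrainerSettings fromProperties(Properties prop) {
        TrainerSettings s = new TrainerSettings();
        String v;
        if ((v = prop.getProperty(KEY_LEARNING_RATE)) != null) s.learningRate = Float.parseFloat(v.trim());
        if ((v = prop.getProperty(KEY_MOMENTUM)) != null) s.momentum = Float.parseFloat(v.trim());
        if ((v = prop.getProperty(KEY_MAX_EPOCHS)) != null) s.maxEpochs = Long.parseLong(v.trim());
        if ((v = prop.getProperty(KEY_BATCH_MODE)) != null) s.batchMode = Boolean.parseBoolean(v.trim());
        if ((v = prop.getProperty(KEY_BATCH_SIZE)) != null) s.batchSize = Integer.parseInt(v.trim());
        return s;
    }

    public static void main(String[] args) {
        Properties prop = new Properties();
        prop.setProperty(KEY_LEARNING_RATE, "0.05");
        prop.setProperty(KEY_BATCH_MODE, "true");
        TrainerSettings s = fromProperties(prop);
        System.out.println(s.learningRate + " " + s.batchMode + " " + s.maxEpochs);
    }
}
```

Keys that are absent keep their defaults, which mirrors the "sets properties from available keys" behaviour described for setProperties.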
Modifier and Type | Method and Description |
---|---|
void | addListener(TrainingListener listener) - Adds a training listener to this trainer. |
boolean | createsTrainingSnaphots() - Returns true if training snapshots of the network are created, false otherwise. |
int | getBatchSize() - Batch size is the number of training examples after which the network's weights are adjusted. |
int | getCheckpointEpochs() - Number of epochs between snapshots of the trained network. |
int | getCurrentEpoch() - Returns the current training epoch (iteration) of this trainer. |
float | getDropout() - Dropout is a technique to prevent overfitting that skips adjusting the weights of some neurons with the given probability. |
boolean | getEarlyStopping() - Early stopping stops the training when it starts converging slowly, which helps prevent overfitting. |
float | getEarlyStoppingMinLossChange() - Early stopping stops the training if the error/loss starts converging too slowly; this is the minimum loss change that still counts as progress. |
int | getEarlyStoppingPatience() - Number of epochs to wait to see whether the loss is decreasing too slowly. |
boolean | getExtendedLogging() - Extended logging includes additional information for debugging the training. |
float | getLearningRate() - Learning rate controls the step size, as a fraction of the error, used to adjust the internal parameters (weights) of the neural network. |
long | getMaxEpochs() - Returns the setting for the maximum number of training epochs (iterations). |
float | getMaxError() - Returns the setting for the stopping error threshold. |
float | getMomentum() - Momentum helps avoid oscillations in weight changes, giving more stable and faster training. |
NeuralNetwork<?> | getNeuralNetwork() - Returns the neural network trained by this trainer. |
OptimizerType | getOptimizer() |
boolean | getShuffle() - Returns the shuffle flag, which determines whether the training set is shuffled before each epoch. |
int | getSnapshotEpochs() - Number of epochs between training snapshots. |
String | getSnapshotPath() - Path used for snapshots, which save the current state of the trained network during the training so that it can be restored from that training point if needed. |
float | getStopAccuracy() |
float | getStopError() - Alias for getMaxError(). |
javax.visrec.ml.data.DataSet<?> | getTestSet() - The test set is used after the training to estimate the performance of the trained model and its ability to generalize to new data. |
float | getTrainingAccuracy() - Accuracy metric: the percentage of correct predictions on the training set. |
float | getTrainingLoss() - Total training error/loss at the current epoch. |
float | getValidationAccuracy() - Accuracy metric: the percentage of correct predictions on the validation set. |
float | getValidationLoss() - Validation loss is the error calculated on the validation set; it is used to detect overfitting and to validate the architecture and training settings. |
boolean | isBatchMode() - In batch mode the weights are adjusted after a pass over all examples in the training set, while in online mode they are adjusted after each training example. |
void | removeListener(TrainingListener listener) - Removes a training listener from this trainer. |
BackpropagationTrainer | setBatchMode(boolean batchMode) - Sets whether to use batch mode during the training. |
BackpropagationTrainer | setBatchSize(int batchSize) - Batch size is the number of training examples after which the network's weights are adjusted. |
BackpropagationTrainer | setCheckpointEpochs(int checkpointEpochs) - Sets the number of epochs between snapshots of the trained network. |
BackpropagationTrainer | setDropout(float dropout) - Dropout is a technique to prevent overfitting that skips adjusting the weights of some neurons with the given probability. |
BackpropagationTrainer | setEarlyStopping(boolean earlyStopping) - Early stopping stops the training when it starts converging slowly, which helps prevent overfitting. |
BackpropagationTrainer | setEarlyStoppingMinLossChange(float earlyStoppingMinLossChange) - Early stopping stops the training if the error/loss starts converging too slowly; this sets the minimum loss change that still counts as progress. |
BackpropagationTrainer | setEarlyStoppingPatience(int earlyStoppingPatience) - Number of epochs to wait to see whether the loss is decreasing too slowly. |
void | setExtendedLogging(boolean extendedLogging) - Extended logging includes additional information for debugging the training. |
BackpropagationTrainer | setL1Regularization(float regL1) - L1 regularization (sum of absolute values of the weights) is used to prevent overfitting and overly large weights. |
BackpropagationTrainer | setL2Regularization(float regL2) - L2 regularization (sum of squares of the weights) is used to prevent overfitting and overly large weights. |
BackpropagationTrainer | setLearningRate(float learningRate) - Learning rate controls the step size, as a fraction of the error, used to adjust the internal parameters (weights) of the neural network. |
BackpropagationTrainer | setLearningRateDecay(float decayRate) - Learning rate decay lowers the learning rate with each epoch by the decayRate factor, which may help the error keep decreasing. |
BackpropagationTrainer | setMaxEpochs(long maxEpochs) - Deprecated. Use setStopEpochs instead. |
BackpropagationTrainer | setMaxError(float maxError) - Sets the stopping error threshold for this training. |
BackpropagationTrainer | setMomentum(float momentum) - Momentum helps avoid oscillations in weight changes, giving more stable and faster training. |
BackpropagationTrainer | setOptimizer(OptimizerType optimizer) |
void | setProperties(Properties prop) - Sets properties from the available keys in the specified prop object. |
BackpropagationTrainer | setShuffle(boolean shuffle) - Sets the shuffle flag, which determines whether the training set is shuffled before each epoch. |
void | setSnapshotEpochs(int snapshotEpochs) - Sets the number of epochs between training snapshots. |
BackpropagationTrainer | setSnapshotPath(String snapshotPath) - Path used for snapshots, which save the current state of the trained network during the training so that it can be restored from that training point. |
void | setStopAccuracy(float stopAccuracy) |
BackpropagationTrainer | setStopEpochs(long stopEpochs) - Sets the number of epochs (iterations) to run the training. |
BackpropagationTrainer | setStopError(float stopError) - The training stops when/if the training error reaches this value. |
void | setTestSet(javax.visrec.ml.data.DataSet<MLDataItem> testSet) - The test set is used after the training to estimate the performance of the trained model and its ability to generalize to new data. |
void | setTrainingSnapshots(boolean trainingSnapshots) - Training snapshots save the current state of the trained neural network during the training so that it can be restored from that training point if needed. |
void | stop() - Stops the training. |
void | train(javax.visrec.ml.data.DataSet<?> trainingSet, double valSplit) - Runs the training using the given training set, splitting off part of it to use as a validation set. |
void | train(javax.visrec.ml.data.DataSet<? extends MLDataItem> trainingSet) - Runs the training using the specified training set. |
void | train(javax.visrec.ml.data.DataSet<MLDataItem> trainingSet, javax.visrec.ml.data.DataSet<MLDataItem> validationSet) - Runs the training using the given training and validation sets. |
void | updateLearningRate(float learningRate) - Updates the learning rate for all layers during learning rate decay. |
public static final String PROP_MAX_ERROR
public static final String PROP_MAX_EPOCHS
public static final String PROP_LEARNING_RATE
public static final String PROP_MOMENTUM
public static final String PROP_BATCH_MODE
public static final String PROP_BATCH_SIZE
public static final String PROP_OPTIMIZER_TYPE
public BackpropagationTrainer(NeuralNetwork neuralNet)
neuralNet
- neural network to train using this instance of the backpropagation algorithm
public BackpropagationTrainer(Properties prop)
prop
- key/value pairs of properties for backpropagation
public void train(javax.visrec.ml.data.DataSet<MLDataItem> trainingSet, javax.visrec.ml.data.DataSet<MLDataItem> validationSet)
trainingSet
- set of example data to train the network
validationSet
- set of example data to validate the network during the training
public void train(javax.visrec.ml.data.DataSet<?> trainingSet, double valSplit)
trainingSet
- set of example data to train the network
valSplit
- fraction of the training set to use as a validation set, a value between 0 and 1, commonly 0.1 or 0.2
public void train(javax.visrec.ml.data.DataSet<? extends MLDataItem> trainingSet)
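For train(trainingSet, valSplit), valSplit is the fraction of the training set held out for validation. This page does not say how the held-out examples are chosen, so the sketch below assumes a seeded random shuffle before splitting. ValidationSplitSketch and its split method are hypothetical, and they work on a plain List rather than on javax.visrec.ml.data.DataSet.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical sketch of a train/validation split by fraction; not the Deep Netts implementation. */
final class ValidationSplitSketch {
    /** Returns [trainingPart, validationPart]; valSplit must be strictly between 0 and 1. */
    static <T> List<List<T>> split(List<T> data, double valSplit, long seed) {
        if (!(valSplit > 0.0 && valSplit < 1.0)) {
            throw new IllegalArgumentException("valSplit must be between 0 and 1, got " + valSplit);
        }
        List<T> shuffled = new ArrayList<>(data);
        Collections.shuffle(shuffled, new Random(seed)); // assumption: examples are shuffled first
        int valSize = (int) Math.round(shuffled.size() * valSplit);
        List<T> validation = new ArrayList<>(shuffled.subList(0, valSize));
        List<T> training = new ArrayList<>(shuffled.subList(valSize, shuffled.size()));
        return List.of(training, validation);
    }

    public static void main(String[] args) {
        List<Integer> data = new ArrayList<>();
        for (int i = 0; i < 10; i++) data.add(i);
        List<List<Integer>> parts = split(data, 0.2, 42L);
        System.out.println("training=" + parts.get(0) + " validation=" + parts.get(1));
    }
}
```

With valSplit = 0.2 and 10 examples, 2 go to validation and 8 remain for training; the two parts are disjoint and together cover the original data.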
public long getMaxEpochs()
public BackpropagationTrainer setMaxEpochs(long maxEpochs)
Deprecated. Use setStopEpochs instead.
maxEpochs
- the maximum number of training epochs (iterations) for training the network
public BackpropagationTrainer setStopEpochs(long stopEpochs)
stopEpochs
- number of epochs after which the training will stop
public float getMaxError()
public float getStopError()
public BackpropagationTrainer setMaxError(float maxError)
maxError
- maximum error threshold
public BackpropagationTrainer setStopError(float stopError)
stopError
- value of the training error at which to stop the training
public float getStopAccuracy()
public void setStopAccuracy(float stopAccuracy)
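setStopEpochs, setStopError (alias of setMaxError) and setStopAccuracy define when the training ends. The sketch below combines them into one hypothetical check. The stopEpochs and stopError behaviour follows the descriptions above; treating stopAccuracy as "stop once training accuracy reaches this value" is an assumption, because this page does not document it. StopCriteria is a hypothetical class, not part of Deep Netts.

```java
/** Hypothetical sketch of combined stopping criteria; not the Deep Netts implementation. */
final class StopCriteria {
    private final long stopEpochs;
    private final float stopError;
    private final float stopAccuracy;

    StopCriteria(long stopEpochs, float stopError, float stopAccuracy) {
        this.stopEpochs = stopEpochs;
        this.stopError = stopError;
        this.stopAccuracy = stopAccuracy;
    }

    /** Returns true when any criterion is met after the given (1-based) epoch. */
    boolean shouldStop(long epoch, float trainingLoss, float trainingAccuracy) {
        if (epoch >= stopEpochs) return true;        // ran the requested number of epochs
        if (trainingLoss <= stopError) return true;  // training error reached the threshold
        return trainingAccuracy >= stopAccuracy;     // assumption: accuracy-based stop
    }

    public static void main(String[] args) {
        StopCriteria criteria = new StopCriteria(100, 0.01f, 0.99f);
        System.out.println(criteria.shouldStop(10, 0.2f, 0.8f));   // no criterion met
        System.out.println(criteria.shouldStop(10, 0.005f, 0.8f)); // error threshold reached
    }
}
```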
public BackpropagationTrainer setLearningRate(float learningRate)
learningRate
- a value in the range (0, 1); 0.01 is used as the default initial value
public float getLearningRate()
public NeuralNetwork<?> getNeuralNetwork()
public void updateLearningRate(float learningRate)
learningRate
- the learning rate value to set for all layers
See Also: LearningRateDecay
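setLearningRateDecay lowers the learning rate each epoch, and updateLearningRate pushes the new value to all layers. The exact schedule is not specified on this page; the sketch below assumes the common time-based form lr(e) = lr0 / (1 + decayRate * e), where e is the 0-based epoch. LearningRateDecaySketch is hypothetical, and Deep Netts may use a different formula.

```java
/** Hypothetical time-based learning rate decay; the actual Deep Netts schedule may differ. */
final class LearningRateDecaySketch {
    /** Learning rate for a 0-based epoch: lr0 / (1 + decayRate * epoch). */
    static float decayedRate(float initialRate, float decayRate, int epoch) {
        if (decayRate < 0f) throw new IllegalArgumentException("decayRate must be >= 0");
        return initialRate / (1f + decayRate * epoch);
    }

    public static void main(String[] args) {
        float lr0 = 0.01f;  // default initial learning rate stated in this reference
        float decay = 0.1f; // illustrative decay rate
        for (int epoch = 0; epoch < 5; epoch++) {
            // In a trainer, this is where the new rate would be applied to all layers.
            System.out.printf("epoch %d: lr = %.5f%n", epoch, decayedRate(lr0, decay, epoch));
        }
    }
}
```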
public BackpropagationTrainer setLearningRateDecay(float decayRate)
decayRate
- factor by which the learning rate is lowered with each epoch
public BackpropagationTrainer setL2Regularization(float regL2)
regL2
- coefficient for L2 regularization
public BackpropagationTrainer setL1Regularization(float regL1)
regL1
- coefficient for L1 regularization
public boolean getShuffle()
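setL1Regularization and setL2Regularization add a penalty on weight size to the loss. Following the descriptions in the method summary (L1 is the sum of absolute values, L2 the sum of squares), the sketch below computes the penalty regL1 * sum|w| + regL2 * sum w^2 and its per-weight gradient regL1 * sign(w) + 2 * regL2 * w. Whether Deep Netts scales the L2 term by 1/2 or normalizes by the number of examples is not stated here; RegularizationSketch is hypothetical.

```java
/** Hypothetical L1/L2 weight penalty; scaling conventions in Deep Netts may differ. */
final class RegularizationSketch {
    /** Penalty added to the loss: regL1 * sum(|w|) + regL2 * sum(w^2). */
    static double penalty(double[] weights, double regL1, double regL2) {
        double sumAbs = 0;
        double sumSq = 0;
        for (double w : weights) {
            sumAbs += Math.abs(w);
            sumSq += w * w;
        }
        return regL1 * sumAbs + regL2 * sumSq;
    }

    /** Gradient of the penalty with respect to one weight; added to the loss gradient. */
    static double penaltyGradient(double w, double regL1, double regL2) {
        return regL1 * Math.signum(w) + 2 * regL2 * w;
    }

    public static void main(String[] args) {
        double[] weights = {1.0, -2.0, 0.5};
        System.out.println("penalty = " + penalty(weights, 0.01, 0.001));
        System.out.println("gradient at w = -2: " + penaltyGradient(-2.0, 0.01, 0.001));
    }
}
```

The L1 gradient has constant magnitude and pushes small weights toward exactly zero, while the L2 gradient shrinks each weight in proportion to its size.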
public BackpropagationTrainer setShuffle(boolean shuffle)
shuffle
- true to shuffle the training set before each epoch
public void addListener(TrainingListener listener)
listener
- object that listens for events in this trainer
public void removeListener(TrainingListener listener)
listener
- listener to remove
public boolean isBatchMode()
See Also: setBatchMode(boolean)
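addListener and removeListener follow the observer pattern: the trainer keeps a list of listeners and notifies each of them about training events. The methods of TrainingListener are not shown on this page, so the sketch below defines its own hypothetical EpochListener interface and ListenerRegistry class. A CopyOnWriteArrayList lets listeners be added or removed safely even while a notification is being delivered.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;

/** Hypothetical sketch of listener registration and notification; not Deep Netts code. */
final class ListenerRegistry {
    private final List<EpochListener> listeners = new CopyOnWriteArrayList<>();

    void addListener(EpochListener listener) { listeners.add(listener); }

    void removeListener(EpochListener listener) { listeners.remove(listener); }

    /** Called by the training loop after each epoch. */
    void fireEpochFinished(int epoch, float trainingLoss) {
        for (EpochListener l : listeners) l.epochFinished(epoch, trainingLoss);
    }

    public static void main(String[] args) {
        ListenerRegistry registry = new ListenerRegistry();
        List<String> log = new ArrayList<>();
        EpochListener logger = (epoch, loss) -> log.add("epoch " + epoch + " loss " + loss);
        registry.addListener(logger);
        registry.fireEpochFinished(1, 0.5f);
        registry.removeListener(logger);
        registry.fireEpochFinished(2, 0.4f); // not logged: the listener was removed
        System.out.println(log);
    }
}

/** Hypothetical listener interface; the real TrainingListener API may differ. */
interface EpochListener {
    void epochFinished(int epoch, float trainingLoss);
}
```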
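public BackpropagationTrainer setBatchMode(boolean batchMode)
batchMode
- true to adjust weights after a pass over all training examples (batch mode), false to adjust them after each example (online mode)
public int getBatchSize()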
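When weights are adjusted after every batchSize examples, one epoch over n examples makes ceil(n / batchSize) updates: a batch size of 1 behaves like online mode and a batch size of n like full batch mode. The sketch below shows this gradient accumulation for a one-weight model y = w * x with squared error. How Deep Netts combines the batchMode flag with batchSize is not spelled out on this page; MiniBatchSketch is hypothetical.

```java
/** Hypothetical mini-batch gradient descent for y = w * x; not the Deep Netts implementation. */
final class MiniBatchSketch {
    /** Weight updates per epoch when weights are adjusted after every batchSize examples. */
    static int updatesPerEpoch(int numExamples, int batchSize) {
        if (batchSize < 1) throw new IllegalArgumentException("batchSize must be >= 1");
        return (numExamples + batchSize - 1) / batchSize; // ceiling division; last batch may be partial
    }

    /** One epoch: accumulate the gradient of 0.5 * (w*x - y)^2 and update w at the end of each batch. */
    static double trainEpoch(double w, double[] xs, double[] ys, int batchSize, double learningRate) {
        double gradSum = 0;
        int inBatch = 0;
        for (int i = 0; i < xs.length; i++) {
            double error = w * xs[i] - ys[i];
            gradSum += error * xs[i];
            inBatch++;
            if (inBatch == batchSize || i == xs.length - 1) {
                w -= learningRate * gradSum / inBatch; // step along the batch-averaged gradient
                gradSum = 0;
                inBatch = 0;
            }
        }
        return w;
    }

    public static void main(String[] args) {
        double[] xs = {1, 2, 3};
        double[] ys = {2, 4, 6}; // generated by y = 2x
        double w = 0;
        for (int epoch = 0; epoch < 50; epoch++) w = trainEpoch(w, xs, ys, 2, 0.1);
        System.out.println("w = " + w + ", updates per epoch = " + updatesPerEpoch(xs.length, 2));
    }
}
```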
public BackpropagationTrainer setBatchSize(int batchSize)
batchSize
- number of training examples after which the network's weights are adjusted
public BackpropagationTrainer setMomentum(float momentum)
momentum
- a decimal value greater than zero and less than one
public float getMomentum()
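setMomentum expects a value between 0 and 1. In the classical (heavy-ball) formulation, each weight keeps a velocity v that accumulates past updates: v = momentum * v - learningRate * gradient, then w = w + v. The sketch below applies this textbook update to f(w) = w^2, whose gradient is 2w; it is assumed here for illustration, Deep Netts may formulate momentum differently, and MomentumSketch is hypothetical.

```java
/** Hypothetical classical-momentum gradient descent on f(w) = w^2. */
final class MomentumSketch {
    /** Runs the given number of steps from w0 and returns the final weight. */
    static double minimize(double w0, double learningRate, double momentum, int steps) {
        if (momentum < 0 || momentum >= 1) throw new IllegalArgumentException("momentum must be in [0, 1)");
        double w = w0;
        double velocity = 0;
        for (int t = 0; t < steps; t++) {
            double gradient = 2 * w;                                  // d/dw of w^2
            velocity = momentum * velocity - learningRate * gradient; // accumulate past steps
            w += velocity;
        }
        return w;
    }

    public static void main(String[] args) {
        System.out.println("no momentum:  " + minimize(1.0, 0.01, 0.0, 100));
        System.out.println("momentum 0.9: " + minimize(1.0, 0.01, 0.9, 100));
    }
}
```

With a small learning rate, the momentum run ends much closer to the minimum at w = 0 after the same number of steps, which is the "faster training" effect described for setMomentum.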
public void stop()
public float getTrainingLoss()
public float getValidationLoss()
public float getTrainingAccuracy()
public float getValidationAccuracy()
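getTrainingAccuracy and getValidationAccuracy report the share of correct predictions. For classification this is usually computed by comparing the index of the largest network output with the index of the target class, as in the sketch below. The argmax rule and the 0-to-1 scale are assumptions (this page does not say whether the value is returned as a fraction or as a percentage); AccuracySketch is hypothetical.

```java
/** Hypothetical classification accuracy via argmax; not the Deep Netts implementation. */
final class AccuracySketch {
    /** Index of the largest value (the first one wins on ties). */
    static int argmax(float[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) best = i;
        }
        return best;
    }

    /** Fraction (0..1) of examples whose predicted class matches the target class. */
    static float accuracy(float[][] outputs, int[] targetClasses) {
        if (outputs.length != targetClasses.length || outputs.length == 0) {
            throw new IllegalArgumentException("need the same, non-zero number of outputs and targets");
        }
        int correct = 0;
        for (int i = 0; i < outputs.length; i++) {
            if (argmax(outputs[i]) == targetClasses[i]) correct++;
        }
        return (float) correct / outputs.length;
    }

    public static void main(String[] args) {
        float[][] outputs = {{0.9f, 0.1f}, {0.2f, 0.8f}, {0.6f, 0.4f}, {0.3f, 0.7f}};
        int[] targets = {0, 1, 1, 1};
        System.out.println("accuracy = " + accuracy(outputs, targets)); // 3 of 4 predictions correct
    }
}
```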
public int getCurrentEpoch()
public OptimizerType getOptimizer()
public BackpropagationTrainer setOptimizer(OptimizerType optimizer)
public javax.visrec.ml.data.DataSet<?> getTestSet()
public void setTestSet(javax.visrec.ml.data.DataSet<MLDataItem> testSet)
testSet
- example data not used during the training, which will be used to evaluate/test the trained model
public boolean getEarlyStopping()
public BackpropagationTrainer setEarlyStopping(boolean earlyStopping)
earlyStopping
- true to enable early stopping
public BackpropagationTrainer setSnapshotPath(String snapshotPath)
snapshotPath
- path where training snapshots are saved
public String getSnapshotPath()
public int getSnapshotEpochs()
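public void setSnapshotEpochs(int snapshotEpochs)
snapshotEpochs
- number of epochs between training snapshots
public void setTrainingSnapshots(boolean trainingSnapshots)
trainingSnapshots
- true to save snapshots of the trained network during the training
public boolean createsTrainingSnaphots()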
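setTrainingSnapshots, setSnapshotEpochs and setSnapshotPath control periodic saving of the network state during training. Since BackpropagationTrainer implements Serializable, standard Java serialization is one plausible mechanism; the sketch below assumes it, together with a hypothetical file naming scheme, and saves a Serializable stand-in for the network state every snapshotEpochs epochs. SnapshotSketch and its file names are not Deep Netts API.

```java
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Path;

/** Hypothetical periodic snapshots via Java serialization; not the Deep Netts implementation. */
final class SnapshotSketch {
    /** True when the (1-based) epoch number is a multiple of snapshotEpochs. */
    static boolean isSnapshotEpoch(int epoch, int snapshotEpochs) {
        return snapshotEpochs > 0 && epoch % snapshotEpochs == 0;
    }

    /** Hypothetical naming scheme: <snapshotPath>/snapshot_epoch_<n>.ser */
    static Path snapshotFile(Path snapshotPath, int epoch) {
        return snapshotPath.resolve("snapshot_epoch_" + epoch + ".ser");
    }

    static void save(Serializable state, Path file) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(Files.newOutputStream(file))) {
            out.writeObject(state);
        }
    }

    static Object load(Path file) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(Files.newInputStream(file))) {
            return in.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        Path dir = Files.createTempDirectory("snapshots");
        for (int epoch = 1; epoch <= 10; epoch++) {
            double[] weights = {epoch * 0.1, -epoch * 0.2}; // stand-in for the network state
            if (isSnapshotEpoch(epoch, 5)) save(weights, snapshotFile(dir, epoch));
        }
        double[] restored = (double[]) load(snapshotFile(dir, 5));
        System.out.println("restored weights from epoch 5: " + restored[0] + ", " + restored[1]);
    }
}
```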
public float getEarlyStoppingMinLossChange()
public BackpropagationTrainer setEarlyStoppingMinLossChange(float earlyStoppingMinLossChange)
earlyStoppingMinLossChange
- minimum change of the loss below which the training is considered to be converging too slowly
public int getEarlyStoppingPatience()
public BackpropagationTrainer setEarlyStoppingPatience(int earlyStoppingPatience)
earlyStoppingPatience
- number of epochs to wait for the loss to decrease before early stopping
public int getCheckpointEpochs()
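setEarlyStopping, setEarlyStoppingMinLossChange and setEarlyStoppingPatience describe a patience-based rule: if the loss has not improved by more than the minimum change for patience consecutive epochs, training stops. The sketch below is the common form of that rule; which loss Deep Netts monitors (training or validation) and exactly how it compares values are assumptions here, and EarlyStoppingSketch is hypothetical.

```java
/** Hypothetical patience-based early stopping; not the Deep Netts implementation. */
final class EarlyStoppingSketch {
    private final double minLossChange;
    private final int patience;
    private double bestLoss = Double.POSITIVE_INFINITY;
    private int epochsWithoutImprovement = 0;

    EarlyStoppingSketch(double minLossChange, int patience) {
        this.minLossChange = minLossChange;
        this.patience = patience;
    }

    /** Records the loss of the finished epoch and returns true if training should stop. */
    boolean update(double loss) {
        if (bestLoss - loss > minLossChange) { // improved by more than the minimum change
            bestLoss = loss;
            epochsWithoutImprovement = 0;
        } else {
            epochsWithoutImprovement++;        // too little progress this epoch
        }
        return epochsWithoutImprovement >= patience;
    }

    public static void main(String[] args) {
        EarlyStoppingSketch es = new EarlyStoppingSketch(0.01, 2);
        double[] losses = {1.0, 0.5, 0.495, 0.494, 0.3};
        for (int epoch = 0; epoch < losses.length; epoch++) {
            if (es.update(losses[epoch])) {
                System.out.println("stopping after epoch " + (epoch + 1));
                break;
            }
        }
    }
}
```

Comparing against the best loss seen so far, rather than the previous epoch, keeps a single noisy epoch from resetting the patience counter.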
public BackpropagationTrainer setCheckpointEpochs(int checkpointEpochs)
checkpointEpochs
- number of epochs between snapshots of the trained network
public final void setProperties(Properties prop)
prop
- properties whose available keys are used to configure this trainer
public BackpropagationTrainer setDropout(float dropout)
dropout
- a value between 0.2 and 0.8 that represents the probability of skipping weight adjustment
public float getDropout()
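setDropout takes the probability (0.2 to 0.8) of skipping weight adjustment for a neuron. Following that description, the sketch below draws an independent skip decision per neuron and applies weight deltas only to the neurons that are kept. The more widespread form of dropout zeroes neuron outputs during the forward pass; whether Deep Netts also does that is not stated here. DropoutSketch is hypothetical, and it uses one weight per neuron for brevity.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical dropout as "skip weight adjustment with probability p"; not Deep Netts code. */
final class DropoutSketch {
    /** mask[i] is true if neuron i's weights should be adjusted in this step. */
    static boolean[] updateMask(int neurons, float dropout, Random random) {
        if (dropout < 0f || dropout >= 1f) throw new IllegalArgumentException("dropout must be in [0, 1)");
        boolean[] mask = new boolean[neurons];
        for (int i = 0; i < neurons; i++) {
            mask[i] = random.nextFloat() >= dropout; // skipped with probability `dropout`
        }
        return mask;
    }

    /** Applies deltas (one per neuron, for brevity) only where the mask allows it. */
    static void applyDeltas(float[] weights, float[] deltas, boolean[] mask) {
        for (int i = 0; i < weights.length; i++) {
            if (mask[i]) weights[i] += deltas[i];
        }
    }

    public static void main(String[] args) {
        float[] weights = {0.5f, -0.3f, 0.8f, 0.1f};
        float[] deltas = {0.1f, 0.1f, 0.1f, 0.1f};
        boolean[] mask = updateMask(weights.length, 0.5f, new Random(1));
        applyDeltas(weights, deltas, mask);
        System.out.println(Arrays.toString(weights));
    }
}
```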
public boolean getExtendedLogging()
public void setExtendedLogging(boolean extendedLogging)
extendedLogging
- true to include additional information for debugging the training
Copyright © 2022. All rights reserved.