Package deepnetts.net
Class NeuralNetwork<T extends Trainer>
java.lang.Object
deepnetts.net.NeuralNetwork<T>
- All Implemented Interfaces:
TrainerProvider<T>, Serializable
- Direct Known Subclasses:
ConvolutionalNetwork, FeedForwardNetwork
public class NeuralNetwork<T extends Trainer>
extends Object
implements TrainerProvider<T>, Serializable
Base class for all neural networks in Deep Netts.
Holds a list of abstract layers and a loss function.
Provides methods for the forward and backward passes and for accessing the input and output layers.
Also holds the network label and the output labels.
-
Method Summary
- void applyWeightChanges(): Applies calculated weight changes to all layers.
- void backward(): Performs a backward pass across all layers in the neural network, which calculates the corrections for the network's internal parameters (weights).
- getInputLayer(): Returns the input layer of this neural network.
- float getL1RegSum(): Calculates and returns the L1 regularization sum of the entire network (all layers included).
- float getL2RegSum(): Calculates and returns the L2 regularization sum of the entire network (all layers included).
- getLabel(): Returns the label (name) of this neural network.
- getLayerAt(int idx)
- getLayers(): Gets the layers of this neural network.
- getLossFunction(): Returns the loss function of this network, which is used to calculate the total network error during training.
- getMode()
- getNormalizer(): Returns the data normalization method that is applied to the network's inputs.
- getOutput(): Returns the network's output.
- getOutputLabel(int i): Gets the label of the i-th output of this network.
- String[] getOutputLabels(): Returns all labels for the outputs of this network.
- getOutputLayer(): Returns the output layer of this network.
- getPreprocessing(): Gets the preprocessing that needs to be performed before input is fed to this network.
- getTrainer(): Returns the training algorithm of this neural network.
- static <T> T load(String fileName, Class<T> clazz): Loads and returns a neural network previously saved to a file.
- predict(TensorBase input): Returns the prediction of this neural network for the given input.
- void save(String fileName): Saves this network to the file with the specified fileName, using serialization.
- void setInput(TensorBase inputs): Sets the network input and calculates the entire network (triggers the forward pass).
- void setLabel(String label): Sets the label (name) of this neural network.
- void setLossFunction(LossFunction lossFunction): Sets the loss function of this network, which is used to calculate the total network error during training.
- void setMode
- void setNormalizer(AbstractScaler normalizer): Sets the data normalization method that is applied to the network's inputs.
- void setOutputError(TensorBase outputErrors): Sets the network's output errors, which are the difference between the actual (predicted) and target outputs.
- void setOutputLabels(String... outputLabels): Sets the output labels of this network.
- void setPreprocessing(Preprocessing<TensorBase> preprocessing): Sets the preprocessing that needs to be performed before input is fed to this network.
- void setTrainer(T trainer): Sets the training algorithm of this neural network.
- javax.visrec.ml.eval.EvaluationMetrics test(javax.visrec.ml.data.DataSet<? extends MLDataItem> testSet): Tests how well this network predicts, using the specified test set.
- String toString(): Returns a string representation of this network, including all layers and settings.
- void train(javax.visrec.ml.data.DataSet<? extends MLDataItem> trainingSet): Trains the neural network using the specified training set.
-
Method Details
-
setInput
Sets the network input and calculates the entire network (triggers the forward pass).
- Parameters:
inputs
- input tensor
-
getOutput
-
setOutputError
Sets the network's output errors, which are the difference between the actual (predicted) and target outputs.
- Parameters:
outputErrors
- array of errors for each output: the difference between the actual (predicted) and target output
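The output errors passed to setOutputError are an element-wise difference. The following minimal sketch uses plain float arrays in place of TensorBase. It computes that difference and one common aggregate of it, the mean squared error. The class and method names are hypothetical, and Deep Netts' own loss functions may scale or aggregate the error differently.

```java
import java.util.Arrays;

// Hypothetical sketch, not Deep Netts code: plain float arrays stand in for TensorBase.
class OutputErrorSketch {

    // Element-wise difference between actual (predicted) and target outputs.
    static float[] outputErrors(float[] actual, float[] target) {
        if (actual.length != target.length) {
            throw new IllegalArgumentException("actual and target must have the same length");
        }
        float[] errors = new float[actual.length];
        for (int i = 0; i < actual.length; i++) {
            errors[i] = actual[i] - target[i];
        }
        return errors;
    }

    // Mean squared error: the average of the squared output errors.
    static float meanSquaredError(float[] errors) {
        float sum = 0f;
        for (float e : errors) {
            sum += e * e;
        }
        return sum / errors.length;
    }

    public static void main(String[] args) {
        float[] errors = outputErrors(new float[] {0.75f, 0.25f}, new float[] {1.0f, 0.0f});
        System.out.println(Arrays.toString(errors));   // [-0.25, 0.25]
        System.out.println(meanSquaredError(errors));  // 0.0625
    }
}
```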
-
train
Trains the neural network using the specified training set.
- Parameters:
trainingSet
- example data given as (input, output) pairs used to train the network
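Training repeatedly runs the forward pass, measures the output error, calculates corrections and applies them. The sketch below shows that loop for a single linear neuron trained with plain batch gradient descent on the mean squared error. This is a simplifying assumption for illustration, not the training algorithm Deep Netts uses, and all names in it are hypothetical.

```java
// Hypothetical sketch, not the Deep Netts trainer: one linear neuron y = w * x + b
// trained with plain batch gradient descent on the mean squared error.
class TrainingLoopSketch {

    // Returns {w, b} after the given number of epochs over the (input, target) pairs.
    static double[] train(double[] inputs, double[] targets, double learningRate, int epochs) {
        double w = 0.0;
        double b = 0.0;
        int n = inputs.length;
        for (int epoch = 0; epoch < epochs; epoch++) {
            double gradW = 0.0;
            double gradB = 0.0;
            for (int i = 0; i < n; i++) {
                // Forward pass, then output error (actual - target).
                double error = (w * inputs[i] + b) - targets[i];
                // Accumulate the gradients of (1 / 2n) * sum(error^2).
                gradW += error * inputs[i];
                gradB += error;
            }
            // Apply the weight changes (one update per epoch).
            w -= learningRate * gradW / n;
            b -= learningRate * gradB / n;
        }
        return new double[] {w, b};
    }

    public static void main(String[] args) {
        // Training pairs sampled from y = 2x + 1.
        double[] xs = {0, 1, 2, 3};
        double[] ys = {1, 3, 5, 7};
        double[] wb = train(xs, ys, 0.1, 2000);
        System.out.printf("w = %.4f, b = %.4f%n", wb[0], wb[1]);
    }
}
```

The pairs are taken from y = 2x + 1, so the learned values should approach w = 2 and b = 1.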
-
predict
Returns the prediction of this neural network for the given input. This is the main method for using a trained neural network for inference (prediction). A well-trained neural network should provide predictions with low error. Both the input and the returned prediction are tensors, which are essentially multidimensional arrays.
- Parameters:
input
- input for the neural network
- Returns:
- the network's prediction
-
test
public javax.visrec.ml.eval.EvaluationMetrics test(javax.visrec.ml.data.DataSet<? extends MLDataItem> testSet)
Tests how well this network predicts, using the specified test set. Automatically detects which type of task the network is configured to perform and applies the appropriate evaluation procedure using the corresponding Evaluator.
- Parameters:
testSet
- data set used to test (evaluate) the predictions
- Returns:
- evaluation metrics that show how well this network predicts unseen data
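For a classification task, one metric an evaluation can report is accuracy. The following minimal sketch assumes that each prediction is an array of per-class scores and each target is a class index. AccuracySketch is hypothetical and does not reproduce the Evaluator classes or the EvaluationMetrics object returned by test.

```java
// Hypothetical sketch, not a Deep Netts Evaluator: classification accuracy,
// i.e. the fraction of examples whose highest-scoring output is the target class.
class AccuracySketch {

    // Index of the largest score (the first one wins on ties).
    static int argMax(float[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    static float accuracy(float[][] predictions, int[] targetClasses) {
        if (predictions.length == 0 || predictions.length != targetClasses.length) {
            throw new IllegalArgumentException("need one target class per non-empty prediction set");
        }
        int correct = 0;
        for (int i = 0; i < predictions.length; i++) {
            if (argMax(predictions[i]) == targetClasses[i]) {
                correct++;
            }
        }
        return (float) correct / predictions.length;
    }

    public static void main(String[] args) {
        float[][] predictions = {{0.9f, 0.1f}, {0.2f, 0.8f}, {0.6f, 0.4f}};
        int[] targets = {0, 1, 1};
        // Two of the three predictions match their targets.
        System.out.println(accuracy(predictions, targets));
    }
}
```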
-
applyWeightChanges
public void applyWeightChanges()
Applies calculated weight changes to all layers.
-
backward
public void backward()
Performs a backward pass across all layers in the neural network, which calculates the corrections for the network's internal parameters (weights). This method invokes the training step of every layer, starting from the last (output) layer and going backwards to the first (input) layer.
-
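The layer ordering described above runs the forward pass from the input layer, the backward pass from the output layer, and then applies the weight changes. It can be sketched with a toy network that only logs which layer it visits. ToyLayer, ToyNetwork and PassOrderSketch are hypothetical names, not Deep Netts types. Real layers would compute activations and gradients instead of logging.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch, not Deep Netts code: records the order in which a toy
// network visits its layers during the forward pass, the backward pass and
// the weight update.
class PassOrderSketch {

    static List<String> run() {
        List<String> log = new ArrayList<>();
        ToyNetwork net = new ToyNetwork();
        for (String name : new String[] {"input", "hidden", "output"}) {
            net.layers.add(new ToyLayer() {
                public void forward() { log.add("forward " + name); }
                public void backward() { log.add("backward " + name); }
                public void applyWeightChanges() { log.add("apply " + name); }
            });
        }
        net.forward();
        net.backward();
        net.applyWeightChanges();
        return log;
    }

    public static void main(String[] args) {
        System.out.println(run());
    }
}

// Hypothetical minimal layer contract used only by this sketch.
interface ToyLayer {
    void forward();
    void backward();
    void applyWeightChanges();
}

class ToyNetwork {
    final List<ToyLayer> layers = new ArrayList<>();

    // Forward pass: from the first (input) layer to the last (output) layer.
    void forward() {
        for (ToyLayer layer : layers) {
            layer.forward();
        }
    }

    // Backward pass: from the last (output) layer back to the first (input) layer,
    // so each layer can use the error signal propagated from the layer after it.
    void backward() {
        for (int i = layers.size() - 1; i >= 0; i--) {
            layers.get(i).backward();
        }
    }

    // After the corrections are calculated, apply them to every layer.
    void applyWeightChanges() {
        for (ToyLayer layer : layers) {
            layer.applyWeightChanges();
        }
    }
}
```

Running it prints the forward steps in input-to-output order, then the backward steps in reverse order, then one weight update per layer.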
getLayers
Gets the layers of this neural network.
- Returns:
- layers of this neural network
-
getLayerAt
-
getInputLayer
Returns the input layer of this neural network. The input layer is the first layer in the network; it accepts the external input and forwards it to the next layer.
- Returns:
- input layer of this neural network
-
getOutputLayer
Returns the output layer of this network. The output layer is the last layer of the network; it provides the network's final result (the predictions).
- Returns:
- output layer of this network
-
setOutputLabels
Sets the output labels of this network.
- Parameters:
outputLabels
- labels which correspond to outputs of the network.
-
getOutputLabels
Returns all labels for the outputs of this network. Each output of the network should have a label that describes what that output represents.
- Returns:
- labels for outputs of this network.
-
getOutputLabel
Gets the label of the i-th output of this network. Each output of the network should have a label that describes what that output represents.
- Parameters:
i
- index (position) of the output
- Returns:
- label of the i-th output
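A common way to use the output labels is to pick the output with the highest value and report its label. This sketch assumes one output per class, in the same order as the labels, as in a typical multi-class classifier. OutputLabelSketch and its methods are hypothetical.

```java
// Hypothetical sketch, not Deep Netts code: interprets a classifier's outputs by
// taking the index of the largest output and looking up its label.
class OutputLabelSketch {

    // Index of the largest value (the first one wins on ties).
    static int argMax(float[] outputs) {
        int best = 0;
        for (int i = 1; i < outputs.length; i++) {
            if (outputs[i] > outputs[best]) {
                best = i;
            }
        }
        return best;
    }

    // Assumes exactly one label per output, in the same order as the outputs.
    static String predictedLabel(float[] outputs, String[] outputLabels) {
        if (outputs.length != outputLabels.length) {
            throw new IllegalArgumentException("need exactly one label per output");
        }
        return outputLabels[argMax(outputs)];
    }

    public static void main(String[] args) {
        String[] labels = {"cat", "dog", "bird"};
        float[] outputs = {0.1f, 0.7f, 0.2f};
        System.out.println(predictedLabel(outputs, labels)); // dog
    }
}
```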
-
getLossFunction
Returns the loss function of this network, which is used to calculate the total network error during training.
- Returns:
- loss function of this network
-
setLossFunction
Sets the loss function of this network, which is used to calculate the total network error during training.
- Parameters:
lossFunction
- loss function to use during the training
-
getLabel
Returns the label (name) of this neural network.
- Returns:
- label of this network
-
setLabel
Sets the label (name) of this neural network.
- Parameters:
label
- label for this network
-
getL2RegSum
public float getL2RegSum()
Calculates and returns the L2 regularization sum of the entire network (all layers included). This value is used during training to help prevent overfitting.
- Returns:
- L2 regularization sum
-
getL1RegSum
public float getL1RegSum()
Calculates and returns the L1 regularization sum of the entire network (all layers included). This value is used during training to help prevent overfitting.
- Returns:
- L1 regularization sum
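Under the usual definitions, the L1 sum adds up the absolute values of all weights and the L2 sum adds up their squares. The following minimal sketch assumes each layer's weights are available as a float array. This page does not state Deep Netts' exact scaling, for example whether a factor such as 1/2 or the regularization coefficient is applied, so treat these formulas as an assumption.

```java
// Hypothetical sketch, not Deep Netts code: each inner array holds the weights
// of one layer, and both sums run over all layers.
class RegularizationSumSketch {

    // L1 sum: total of the absolute values of all weights.
    static float l1RegSum(float[][] layerWeights) {
        float sum = 0f;
        for (float[] weights : layerWeights) {
            for (float w : weights) {
                sum += Math.abs(w);
            }
        }
        return sum;
    }

    // L2 sum: total of the squares of all weights (no 1/2 factor assumed).
    static float l2RegSum(float[][] layerWeights) {
        float sum = 0f;
        for (float[] weights : layerWeights) {
            for (float w : weights) {
                sum += w * w;
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        float[][] layerWeights = {{0.5f, -1.0f}, {2.0f}};
        System.out.println(l1RegSum(layerWeights)); // 3.5
        System.out.println(l2RegSum(layerWeights)); // 5.25
    }
}
```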
-
getTrainer
Returns the training algorithm of this neural network. The training algorithm tunes the network's internal parameters (weights) in order to minimize the error.
- Specified by:
getTrainer
in interface TrainerProvider<T extends Trainer>
- Returns:
- training algorithm of this network
-
setTrainer
Sets the training algorithm of this neural network.
- Specified by:
setTrainer
in interface TrainerProvider<T extends Trainer>
- Parameters:
trainer
- training algorithm to use for this network
-
getNormalizer
Returns the data normalization method that is applied to the network's inputs.
- Returns:
-
setNormalizer
Sets the data normalization method that is applied to the network's inputs.
- Parameters:
normalizer
-
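The normalizer type, AbstractScaler, is not described further on this page. As an illustration only, the sketch below implements min-max scaling, one common normalization method. It learns each feature's minimum and maximum from training inputs and maps new inputs within that range into [0, 1]. MinMaxScalerSketch is hypothetical and is not a Deep Netts class.

```java
import java.util.Arrays;

// Hypothetical sketch, not a Deep Netts AbstractScaler: min-max scaling fitted on
// training inputs and then applied to each input before it reaches the network.
class MinMaxScalerSketch {
    private final float[] min;
    private final float[] max;

    // Learns the per-feature minimum and maximum from the training rows.
    MinMaxScalerSketch(float[][] rows) {
        int cols = rows[0].length;
        min = new float[cols];
        max = new float[cols];
        Arrays.fill(min, Float.POSITIVE_INFINITY);
        Arrays.fill(max, Float.NEGATIVE_INFINITY);
        for (float[] row : rows) {
            for (int j = 0; j < cols; j++) {
                min[j] = Math.min(min[j], row[j]);
                max[j] = Math.max(max[j], row[j]);
            }
        }
    }

    // Maps values inside the fitted range into [0, 1]; constant features map to 0.
    float[] apply(float[] input) {
        float[] scaled = new float[input.length];
        for (int j = 0; j < input.length; j++) {
            float range = max[j] - min[j];
            scaled[j] = range == 0f ? 0f : (input[j] - min[j]) / range;
        }
        return scaled;
    }

    public static void main(String[] args) {
        float[][] training = {{0f, 10f}, {5f, 20f}, {10f, 30f}};
        MinMaxScalerSketch scaler = new MinMaxScalerSketch(training);
        System.out.println(Arrays.toString(scaler.apply(new float[] {5f, 25f}))); // [0.5, 0.75]
    }
}
```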
-
toString
-
getPreprocessing
Gets the preprocessing that needs to be performed before input is fed to this network.
- Returns:
-
setPreprocessing
Sets the preprocessing that needs to be performed before input is fed to this network.
- Parameters:
preprocessing
-
-
save
Saves this network to the file with the specified fileName, using serialization.
- Parameters:
fileName
- name of the file to save the network to
- Throws:
IOException
-
load
public static <T> T load(String fileName, Class<T> clazz) throws IOException, ClassNotFoundException
Loads and returns a neural network previously saved to a file.
- Type Parameters:
T
- type (class) of the network to load and return
- Parameters:
fileName
- name of the file to load the network from
clazz
- class of the neural network to load
- Returns:
- loaded neural network
- Throws:
IOException
ClassNotFoundException
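NeuralNetwork implements Serializable, so saving and loading can follow the standard Java object-serialization pattern. The sketch below assumes that pattern and uses a hypothetical TinyModel class in place of a real network. Its load method mirrors the documented load(String fileName, Class<T> clazz) signature, but it is not the Deep Netts implementation.

```java
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Arrays;

// Hypothetical sketch of the save/load pattern using standard Java serialization;
// TinyModel stands in for a Serializable network and is not a Deep Netts class.
class SerializationSketch {

    static class TinyModel implements Serializable {
        private static final long serialVersionUID = 1L;
        final String label;
        final float[] weights;

        TinyModel(String label, float[] weights) {
            this.label = label;
            this.weights = weights;
        }
    }

    // Writes the whole object graph to the given file.
    static void save(Serializable model, String fileName) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(fileName))) {
            out.writeObject(model);
        }
    }

    // Reads the object back and casts it to the requested class.
    static <T> T load(String fileName, Class<T> clazz) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(fileName))) {
            return clazz.cast(in.readObject());
        }
    }

    // Saves to a temporary file and loads the copy back.
    static TinyModel roundTrip(TinyModel model) {
        try {
            File file = File.createTempFile("tiny-model", ".ser");
            file.deleteOnExit();
            save(model, file.getPath());
            return load(file.getPath(), TinyModel.class);
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException("round trip failed", e);
        }
    }

    public static void main(String[] args) {
        TinyModel copy = roundTrip(new TinyModel("xor", new float[] {0.5f, -0.25f}));
        System.out.println(copy.label + " " + Arrays.toString(copy.weights)); // xor [0.5, -0.25]
    }
}
```

Class.cast gives the caller a typed result. It throws a ClassCastException if the file holds an object of a different class.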
-
getMode
-
setMode
-
getThreadPool
-