public class NeuralNetwork<T extends Trainer> extends Object implements TrainerProvider<T>, Serializable
See Also:
AbstractLayer, LossFunction, Serialized Form

Modifier and Type | Method and Description |
---|---|
`void` | `applyWeightChanges()` Applies calculated weight changes to all layers. |
`void` | `backward()` Performs a backward pass across all layers of the neural network, which calculates corrections for the network's internal parameters (weights). |
`InputLayer` | `getInputLayer()` Returns the input layer of this neural network. |
`float` | `getL1RegSum()` Calculates and returns the L1 regularization sum of the entire network (all layers included). |
`float` | `getL2RegSum()` Calculates and returns the L2 regularization sum of the entire network (all layers included). |
`String` | `getLabel()` Returns the label (name) of this neural network. |
`List<AbstractLayer>` | `getLayers()` Returns the layers of this neural network. |
`LossFunction` | `getLossFunction()` Returns the loss function of this network, which is used to calculate the total network error during training. |
`AbstractScaler` | `getNormalizer()` Returns the data normalization method that is applied to the network's inputs. |
`float[]` | `getOutput()` Returns the network's output. |
`String` | `getOutputLabel(int i)` Returns the label of the i-th output of this network. |
`String[]` | `getOutputLabels()` Returns all labels for the outputs of this network. |
`OutputLayer` | `getOutputLayer()` Returns the output layer of this network. |
`Preprocessing<Tensor>` | `getPreprocessing()` Returns the preprocessing that needs to be performed before input is fed to this network. |
`T` | `getTrainer()` Returns the training algorithm of this neural network. |
`static <T> T` | `load(String fileName, Class<T> clazz)` Loads and returns a neural network previously saved to a file. |
`Tensor` | `predict(Tensor input)` Returns the prediction of this neural network for the given input. |
`void` | `save(String fileName)` Saves this network, using serialization, to the file with the specified fileName. |
`void` | `setInput(Tensor inputs)` Sets the network input and calculates the entire network (triggers a forward pass). |
`void` | `setLabel(String label)` Sets the label (name) of this neural network. |
`void` | `setLossFunction(LossFunction lossFunction)` Sets the loss function of this network, which is used to calculate the total network error during training. |
`void` | `setNormalizer(AbstractScaler normalizer)` Sets the data normalization method that is applied to the network's inputs. |
`void` | `setOutputError(float[] outputErrors)` Sets the network's output errors, which are the differences between the actual (predicted) and target outputs. |
`void` | `setOutputLabels(String... outputLabels)` Sets the output labels of this network. |
`void` | `setPreprocessing(Preprocessing<Tensor> preprocessing)` Sets the preprocessing that needs to be performed before input is fed to this network. |
`void` | `setTrainer(T trainer)` Sets the training algorithm of this neural network. |
`javax.visrec.ml.eval.EvaluationMetrics` | `test(javax.visrec.ml.data.DataSet<? extends MLDataItem> testSet)` Tests how good the predictions of this network are, using the specified test set. |
`String` | `toString()` Returns a string representation of this network, including all layers and settings. |
`void` | `train(javax.visrec.ml.data.DataSet<? extends MLDataItem> trainingSet)` Trains the neural network using the specified training set. |
public void setInput(Tensor inputs)
Parameters:
inputs - input tensor
public float[] getOutput()
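setInput is documented as triggering the forward pass, while getOutput simply returns the network's output. A minimal sketch of that contract, assuming a single fully connected layer with a sigmoid activation, is shown below; TinyDenseNet is a hypothetical class for illustration, not this class's implementation, and it uses plain float arrays instead of Tensor.

```java
import java.util.Arrays;

/** Hypothetical one-layer network illustrating "set input (compute), then read output". */
class TinyDenseNet {
    private final float[][] weights; // weights[o][i]: weight from input i to output neuron o
    private final float[] biases;    // one bias per output neuron
    private float[] output;          // cached result of the most recent forward pass

    TinyDenseNet(float[][] weights, float[] biases) {
        this.weights = weights;
        this.biases = biases;
        this.output = new float[biases.length];
    }

    /** Stores the input and immediately runs the forward pass. */
    void setInput(float[] inputs) {
        float[] out = new float[biases.length];
        for (int o = 0; o < out.length; o++) {
            float sum = biases[o];
            for (int i = 0; i < inputs.length; i++) {
                sum += weights[o][i] * inputs[i];
            }
            out[o] = sigmoid(sum);
        }
        output = out;
    }

    /** Returns a copy of the cached output; no computation happens here. */
    float[] getOutput() {
        return Arrays.copyOf(output, output.length);
    }

    private static float sigmoid(float x) {
        return (float) (1.0 / (1.0 + Math.exp(-x)));
    }
}
```

Computing eagerly in setInput keeps repeated getOutput calls cheap, at the cost that the output always reflects only the most recent input.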
public void setOutputError(float[] outputErrors)
Parameters:
outputErrors - array of errors for each output, i.e. the difference between the actual (predicted) and target output
public void train(javax.visrec.ml.data.DataSet<? extends MLDataItem> trainingSet)
Parameters:
trainingSet - example data given as (input, output) pairs to train the network
public Tensor predict(Tensor input)
Parameters:
input - input for the neural network
See Also:
Tensor
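setOutputError expects one error per output, defined as the actual (predicted) value minus the target value. The hypothetical OutputErrors helper below sketches how such an array could be built from a prediction and its target; the mean squared error it also computes assumes an MSE-style loss, which is only one possible loss function.

```java
/** Hypothetical helpers for building the array passed to setOutputError. */
final class OutputErrors {
    private OutputErrors() {}

    /** error[i] = predicted[i] - target[i], the sign convention described for setOutputError. */
    static float[] of(float[] predicted, float[] target) {
        if (predicted.length != target.length) {
            throw new IllegalArgumentException("predicted and target must have the same length");
        }
        float[] errors = new float[predicted.length];
        for (int i = 0; i < errors.length; i++) {
            errors[i] = predicted[i] - target[i];
        }
        return errors;
    }

    /** Mean squared error over the same errors (assumes an MSE-style loss). */
    static float meanSquared(float[] errors) {
        float sum = 0f;
        for (float e : errors) {
            sum += e * e;
        }
        return sum / errors.length;
    }
}
```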
public javax.visrec.ml.eval.EvaluationMetrics test(javax.visrec.ml.data.DataSet<? extends MLDataItem> testSet)
Parameters:
testSet - data set to test/evaluate predictions
See Also:
EvaluationMetrics, ClassificationMetrics, RegressionMetrics, Evaluators
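test returns an EvaluationMetrics object, and the related types point to separate classification and regression metrics. As a rough, standalone illustration of what such metrics typically measure, the sketch below computes accuracy for class predictions and root mean squared error for regression outputs. SimpleMetrics is hypothetical and does not reproduce the javax.visrec metric classes.

```java
/** Hypothetical standalone versions of two common evaluation metrics. */
final class SimpleMetrics {
    private SimpleMetrics() {}

    /** Fraction of samples whose predicted class index equals the actual class index. */
    static double accuracy(int[] predicted, int[] actual) {
        if (predicted.length == 0 || predicted.length != actual.length) {
            throw new IllegalArgumentException("need equally sized, non-empty arrays");
        }
        int correct = 0;
        for (int i = 0; i < predicted.length; i++) {
            if (predicted[i] == actual[i]) {
                correct++;
            }
        }
        return (double) correct / predicted.length;
    }

    /** Root mean squared error between predicted and actual values. */
    static double rmse(double[] predicted, double[] actual) {
        if (predicted.length == 0 || predicted.length != actual.length) {
            throw new IllegalArgumentException("need equally sized, non-empty arrays");
        }
        double sum = 0.0;
        for (int i = 0; i < predicted.length; i++) {
            double diff = predicted[i] - actual[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum / predicted.length);
    }
}
```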
public void applyWeightChanges()
public void backward()
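backward() only calculates corrections for the weights, and applyWeightChanges() applies them; this split is the usual shape of gradient-based training. The sketch below reproduces it for a single linear neuron trained by plain gradient descent on squared error. LinearNeuron, its learning rate, and its update rule are assumptions for illustration, not the internals of this class's layers.

```java
/** Hypothetical single linear neuron with separate "calculate" and "apply" steps. */
class LinearNeuron {
    private final float[] weights;
    private float bias;
    private final float learningRate;

    private float[] input;        // last input seen by forward(); backward() needs it
    private float outputError;    // predicted - target, the setOutputError convention
    private final float[] deltaW; // weight corrections computed by backward()
    private float deltaB;         // bias correction computed by backward()

    LinearNeuron(int inputs, float learningRate) {
        this.weights = new float[inputs];
        this.deltaW = new float[inputs];
        this.learningRate = learningRate;
    }

    float forward(float[] x) {
        input = x.clone();
        float sum = bias;
        for (int i = 0; i < weights.length; i++) sum += weights[i] * x[i];
        return sum;
    }

    void setOutputError(float error) { outputError = error; }

    /** Calculates corrections only; weights and bias are left untouched. */
    void backward() {
        // For E = 0.5 * error^2: dE/dw_i = error * x_i and dE/db = error.
        for (int i = 0; i < weights.length; i++) deltaW[i] = -learningRate * outputError * input[i];
        deltaB = -learningRate * outputError;
    }

    /** Applies the corrections calculated by the last backward() call. */
    void applyWeightChanges() {
        for (int i = 0; i < weights.length; i++) weights[i] += deltaW[i];
        bias += deltaB;
    }

    float[] getWeights() { return weights.clone(); }
    float getBias() { return bias; }
}
```

Keeping the two steps apart lets a trainer, for example, accumulate corrections over a mini-batch before applying them once.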
public List<AbstractLayer> getLayers()
public InputLayer getInputLayer()
See Also:
InputLayer
public OutputLayer getOutputLayer()
See Also:
OutputLayer
public void setOutputLabels(String... outputLabels)
Parameters:
outputLabels - labels that correspond to the outputs of the network
public String[] getOutputLabels()
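Output labels are typically used to turn the index of the strongest output into a human-readable class name. The hypothetical LabelLookup helper below takes arrays such as those returned by getOutput() and getOutputLabels() and returns the label at the arg-max position, assuming one label per output in the same order.

```java
/** Hypothetical helper mapping the strongest output to its label. */
final class LabelLookup {
    private LabelLookup() {}

    /** Returns the label of the largest output; ties resolve to the first such index. */
    static String bestLabel(float[] outputs, String[] labels) {
        if (outputs.length == 0 || outputs.length != labels.length) {
            throw new IllegalArgumentException("need exactly one label per output");
        }
        int best = 0;
        for (int i = 1; i < outputs.length; i++) {
            if (outputs[i] > outputs[best]) best = i;
        }
        return labels[best];
    }
}
```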
public String getOutputLabel(int i)
Parameters:
i - index (position) of the output
public LossFunction getLossFunction()
See Also:
LossFunction, LossType
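The loss function determines how the total network error is calculated during training, and the reference to LossType suggests it is selected from predefined types. Two losses commonly offered for this purpose are mean squared error (regression) and cross-entropy (classification); the sketch below shows both formulas in plain Java. SimpleLoss and its implementations are hypothetical and do not mirror the LossFunction interface, and conventions differ between libraries (some scale MSE by 1/2).

```java
/** Hypothetical loss abstraction: total error for one predicted/target pair. */
interface SimpleLoss {
    double error(double[] predicted, double[] target);
}

/** Mean squared error: average of (predicted - target)^2. */
class MeanSquaredError implements SimpleLoss {
    @Override
    public double error(double[] predicted, double[] target) {
        double sum = 0.0;
        for (int i = 0; i < predicted.length; i++) {
            double d = predicted[i] - target[i];
            sum += d * d;
        }
        return sum / predicted.length;
    }
}

/** Categorical cross-entropy: -sum(target * ln(predicted)), predicted clamped away from 0. */
class CrossEntropy implements SimpleLoss {
    private static final double EPS = 1e-12;

    @Override
    public double error(double[] predicted, double[] target) {
        double sum = 0.0;
        for (int i = 0; i < predicted.length; i++) {
            sum -= target[i] * Math.log(Math.max(predicted[i], EPS));
        }
        return sum;
    }
}
```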
public void setLossFunction(LossFunction lossFunction)
Parameters:
lossFunction - loss function to use during training
public String getLabel()
public void setLabel(String label)
Parameters:
label - label for this network
public float getL2RegSum()
public float getL1RegSum()
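getL1RegSum and getL2RegSum return regularization sums over all layers. Under the common definitions, the L1 sum is the sum of absolute weight values and the L2 sum is the sum of squared weights. Whether this class includes biases or applies a scaling factor such as 1/2 is not stated on this page, so the hypothetical RegSums sketch below (weights only, no scaling) is an assumption.

```java
import java.util.List;

/** Hypothetical computation of network-wide L1 and L2 regularization sums. */
final class RegSums {
    private RegSums() {}

    /** Sum of |w| over every weight of every layer. */
    static float l1(List<float[]> layerWeights) {
        float sum = 0f;
        for (float[] layer : layerWeights) {
            for (float w : layer) sum += Math.abs(w);
        }
        return sum;
    }

    /** Sum of w^2 over every weight of every layer. */
    static float l2(List<float[]> layerWeights) {
        float sum = 0f;
        for (float[] layer : layerWeights) {
            for (float w : layer) sum += w * w;
        }
        return sum;
    }
}
```

During training, such sums are usually multiplied by a regularization coefficient and added to the loss.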
public T getTrainer()
Specified by:
getTrainer in interface TrainerProvider<T extends Trainer>
See Also:
BackpropagationTrainer
public void setTrainer(T trainer)
Specified by:
setTrainer in interface TrainerProvider<T extends Trainer>
Parameters:
trainer - training algorithm to use for this network
See Also:
BackpropagationTrainer
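Because the class is declared as NeuralNetwork<T extends Trainer> implementing TrainerProvider<T>, getTrainer() can return the concrete trainer type without a cast. The self-contained sketch below reproduces that bounded-generic pattern; Trainer, TrainerProvider, SgdTrainer, and SketchNetwork here are hypothetical stand-ins, and only the shape of the generics is taken from this page.

```java
/** Hypothetical stand-in for the Trainer abstraction. */
interface Trainer {
    String name();
}

/** Hypothetical stand-in mirroring the getTrainer/setTrainer pair documented above. */
interface TrainerProvider<T extends Trainer> {
    T getTrainer();
    void setTrainer(T trainer);
}

/** A concrete trainer with its own settings. */
class SgdTrainer implements Trainer {
    float learningRate = 0.01f;

    @Override
    public String name() { return "sgd"; }
}

/** Network parameterized by its trainer type, as in NeuralNetwork<T extends Trainer>. */
class SketchNetwork<T extends Trainer> implements TrainerProvider<T> {
    private T trainer;

    @Override
    public T getTrainer() { return trainer; }

    @Override
    public void setTrainer(T trainer) { this.trainer = trainer; }
}
```

With SketchNetwork<SgdTrainer>, trainer-specific settings such as learningRate are reachable directly through getTrainer().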
public AbstractScaler getNormalizer()
public void setNormalizer(AbstractScaler normalizer)
Parameters:
normalizer - data normalization method to apply to the network's inputs
public String toString()
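setNormalizer takes an AbstractScaler that is applied to the network's inputs. To illustrate one common scaling scheme, the hypothetical MinMaxScaler below learns per-column minimum and maximum values from training rows and rescales each column by that range. Which concrete scalers this library provides, and how AbstractScaler is structured, is not described here, so treat this as a generic sketch.

```java
import java.util.Arrays;

/** Hypothetical min-max scaler: rescales each input column using its training range. */
class MinMaxScaler {
    private final float[] min;
    private final float[] max;

    /** Learns per-column minimum and maximum from the training rows. */
    MinMaxScaler(float[][] rows) {
        int cols = rows[0].length;
        min = new float[cols];
        max = new float[cols];
        Arrays.fill(min, Float.POSITIVE_INFINITY);
        Arrays.fill(max, Float.NEGATIVE_INFINITY);
        for (float[] row : rows) {
            for (int c = 0; c < cols; c++) {
                min[c] = Math.min(min[c], row[c]);
                max[c] = Math.max(max[c], row[c]);
            }
        }
    }

    /**
     * Maps values within the training range into [0, 1]; values outside it fall outside [0, 1].
     * Constant columns are mapped to 0 to avoid dividing by zero.
     */
    float[] apply(float[] row) {
        float[] out = new float[row.length];
        for (int c = 0; c < row.length; c++) {
            float range = max[c] - min[c];
            out[c] = range == 0f ? 0f : (row[c] - min[c]) / range;
        }
        return out;
    }
}
```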
public Preprocessing<Tensor> getPreprocessing()
public void setPreprocessing(Preprocessing<Tensor> preprocessing)
Parameters:
preprocessing - preprocessing to perform before input is fed to this network
public void save(String fileName) throws IOException
Parameters:
fileName - name of the file to save the network to
Throws:
IOException
public static <T> T load(String fileName, Class<T> clazz) throws IOException, ClassNotFoundException
Type Parameters:
T - type (class) of the network to load and return
Parameters:
fileName - name of the file to load the network from
clazz - class of the neural network to load
Throws:
IOException
ClassNotFoundException
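save writes the network using serialization, and the static load(fileName, clazz) reads it back as the requested type. The round trip below sketches that pattern with the standard java.io object streams on a hypothetical SavedModel class. That this class uses exactly ObjectOutputStream and ObjectInputStream internally is an assumption, although its implementing Serializable and load declaring ClassNotFoundException are consistent with it.

```java
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

/** Hypothetical serializable model used to illustrate save/load. */
class SavedModel implements Serializable {
    private static final long serialVersionUID = 1L;
    String label;
    float[] weights;

    SavedModel(String label, float[] weights) {
        this.label = label;
        this.weights = weights;
    }

    /** Writes this object to fileName with Java serialization. */
    void save(String fileName) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(fileName))) {
            out.writeObject(this);
        }
    }

    /** Reads an object from fileName and casts it to the requested class. */
    static <T> T load(String fileName, Class<T> clazz) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(fileName))) {
            return clazz.cast(in.readObject());
        }
    }
}
```

Passing a class that does not match the stored object makes clazz.cast throw a ClassCastException, so the Class<T> argument doubles as a runtime type check.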
Copyright © 2022. All rights reserved.