QUICK SUMMARY: This tutorial walks you through building state-of-the-art AI for image recognition in Java, based on deep learning. It will help you get started with modern AI development using your existing Java skills. The resulting deep learning model can recognize a handwritten digit in a given image.
The example used in this article, recognition of handwritten digits, is commonly used as a basic "hello world" example for deep learning.
Start by cloning this GitHub repo: https://github.com/deepnetts/How-to-Get-Started-With-Deep-Learning-in-Java
To run the example, you also need to download and install Deep Netts.
Data Set
Our data set consists of 60,000 example images of handwritten digits. Each image is 28×28 pixels in size. The images are downloaded and unpacked automatically once you clone and run the example from GitHub.
A few sample images are shown in the image below.
The example images used to train the deep learning model are located in the mnist/training folder. This folder contains ten sub-folders named 0–9, each containing images of a specific digit.
For simplicity, in this example we use a randomly chosen subset of 1,000 images. The mnist folder contains two files: index.txt, which lists the subset of images to use for training, and labels.txt, which contains the image category labels (in this case, the digits 0–9).
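Because each digit has its own sub-folder, an image's label can be read from the name of its parent folder. The sketch below illustrates that idea and the random 1,000-image subset in plain Java. It is not how Deep Netts reads index.txt and labels.txt; the class and method names (MnistSubset, labelOf, randomSubset) are hypothetical.

```java
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical helper illustrating the mnist/training/<digit>/<file>.png layout.
class MnistSubset {

    // The label of an image is the name of the folder that contains it,
    // e.g. "mnist/training/9/00019.png" -> "9".
    static String labelOf(String imagePath) {
        Path parent = Paths.get(imagePath).getParent();
        if (parent == null) {
            throw new IllegalArgumentException("No parent folder in: " + imagePath);
        }
        return parent.getFileName().toString();
    }

    // Returns n randomly chosen paths; a fixed seed makes the choice reproducible.
    static List<String> randomSubset(List<String> paths, int n, long seed) {
        if (n > paths.size()) {
            throw new IllegalArgumentException("Requested more images than available");
        }
        List<String> shuffled = new ArrayList<>(paths); // copy so the caller's list is not modified
        Collections.shuffle(shuffled, new Random(seed));
        return new ArrayList<>(shuffled.subList(0, n));
    }

    public static void main(String[] args) {
        // Synthetic file names following the folder-per-digit layout
        List<String> all = new ArrayList<>();
        for (int digit = 0; digit <= 9; digit++) {
            for (int i = 0; i < 5; i++) {
                all.add("mnist/training/" + digit + "/img" + i + ".png");
            }
        }
        for (String p : randomSubset(all, 8, 42L)) {
            System.out.println(labelOf(p) + " <- " + p);
        }
    }
}
```

With the real data set you would list the files on disk instead of generating names, and pass n = 1000.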
Training a deep learning model in Java code
The model training procedure consists of iteratively presenting example images to a deep learning model (in this case, a convolutional neural network), during which the model automatically tunes its internal parameters in order to lower the output error and increase recognition accuracy.
This automated tuning compares the image labels from the training data with the actual outputs (predictions) of the neural network to calculate the output error, and uses an optimization technique to minimize that error. Each output of the convolutional neural network corresponds to a single digit and represents the probability that the input image belongs to that digit.
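The two ideas in this paragraph, per-digit output probabilities and an output error computed against the label, can be sketched in plain Java. This assumes a softmax output and cross-entropy error (the settings passed to the builder later in this tutorial). It shows the math only, not Deep Netts' internal code, and the class name SoftmaxCrossEntropy is hypothetical.

```java
// Illustrative only: softmax turns raw outputs into probabilities,
// and cross-entropy measures the error against the true label.
class SoftmaxCrossEntropy {

    // Converts raw network outputs (logits) into probabilities that sum to 1.
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) max = Math.max(max, v); // subtract the max for numerical stability
        double sum = 0.0;
        double[] probs = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probs[i] = Math.exp(logits[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) probs[i] /= sum;
        return probs;
    }

    // Cross-entropy for one example with a one-hot label:
    // the negative log of the probability assigned to the correct digit.
    static double crossEntropy(double[] probs, int trueLabel) {
        return -Math.log(probs[trueLabel]);
    }

    public static void main(String[] args) {
        double[] logits = {0.5, 2.0, -1.0, 0, 0, 0, 0, 0, 0, 0}; // one raw output per digit
        double[] probs = softmax(logits);
        System.out.printf("P(digit 1) = %.3f%n", probs[1]);
        System.out.printf("error if label is 1: %.3f%n", crossEntropy(probs, 1));
        System.out.printf("error if label is 2: %.3f%n", crossEntropy(probs, 2));
    }
}
```

The error is small when the network assigns a high probability to the correct digit and grows quickly as that probability approaches zero, which is what training pushes against.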
Building a convolutional neural network
A typical architecture of a convolutional neural network is shown in the image below.
The architecture of a convolutional neural network consists of the following stack of processing blocks, called layers:
- Input layer accepts external input to the network
- Convolutional layers perform pattern detection
- Pooling layers downsample their inputs
- Fully connected layers perform classification
- Output layer provides final processing result – a prediction
The number and size of the layers depend on the problem and are usually determined experimentally.
More details about how all this works can be found in this tutorial.
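To make the roles of the convolutional and pooling layers listed above concrete, here is a toy single-channel version in plain Java. It assumes stride 1 and no padding for the convolution and a 2×2 max pooling with stride 2; as in most deep learning libraries, the "convolution" is computed without flipping the kernel. ToyLayers and its methods are hypothetical and unrelated to Deep Netts' implementation.

```java
// Toy single-channel versions of a convolutional layer (with ReLU) and a max pooling layer.
class ToyLayers {

    // "Valid" 2D convolution (stride 1, no padding) followed by ReLU.
    // For a k x k kernel the output is (h - k + 1) x (w - k + 1).
    static double[][] convolveRelu(double[][] input, double[][] kernel) {
        int k = kernel.length;
        int outH = input.length - k + 1;
        int outW = input[0].length - k + 1;
        double[][] out = new double[outH][outW];
        for (int r = 0; r < outH; r++) {
            for (int c = 0; c < outW; c++) {
                double sum = 0.0;
                for (int i = 0; i < k; i++) {
                    for (int j = 0; j < k; j++) {
                        sum += input[r + i][c + j] * kernel[i][j];
                    }
                }
                out[r][c] = Math.max(0.0, sum); // ReLU keeps only positive responses
            }
        }
        return out;
    }

    // 2x2 max pooling with stride 2: halves width and height, keeping the strongest response.
    static double[][] maxPool2x2(double[][] input) {
        int outH = input.length / 2;
        int outW = input[0].length / 2;
        double[][] out = new double[outH][outW];
        for (int r = 0; r < outH; r++) {
            for (int c = 0; c < outW; c++) {
                out[r][c] = Math.max(
                        Math.max(input[2 * r][2 * c], input[2 * r][2 * c + 1]),
                        Math.max(input[2 * r + 1][2 * c], input[2 * r + 1][2 * c + 1]));
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] image = new double[6][6];
        for (int r = 0; r < 6; r++) image[r][2] = 1.0;                 // vertical line in column 2
        double[][] vertical = {{-1, 2, -1}, {-1, 2, -1}, {-1, 2, -1}}; // responds to vertical lines
        double[][] features = convolveRelu(image, vertical);          // 4x4 feature map
        double[][] pooled = maxPool2x2(features);                      // 2x2 after pooling
        System.out.println(java.util.Arrays.deepToString(pooled));
    }
}
```

The kernel "detects" the vertical line (a strong response only where it lines up), and pooling keeps that response while shrinking the map. A real network learns its kernel values during training instead of using hand-picked ones.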
The following code segment creates an instance of a convolutional neural network using its builder. Details about all the settings used in the builder are available in the apidocs.
ConvolutionalNetwork neuralNet = ConvolutionalNetwork.builder()
        .addInputLayer(imageWidth, imageHeight)               // input: image width and height
        .addConvolutionalLayer(12, 5)                         // pattern detection
        .addMaxPoolingLayer(2, 2)                             // downsampling
        .addFullyConnectedLayer(60)                           // classification
        .addOutputLayer(labelsCount, ActivationType.SOFTMAX)  // one output (probability) per label
        .hiddenActivationFunction(ActivationType.RELU)        // activation used in hidden layers
        .lossFunction(LossType.CROSS_ENTROPY)                 // error function minimized during training
        .build();
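To see how the data shrinks as it passes through these layers, the sketch below applies the standard output-size formula, (input - filter + 2·padding) / stride + 1. It assumes that addConvolutionalLayer(12, 5) means 12 feature maps with 5×5 filters, stride 1 and no padding, and that addMaxPoolingLayer(2, 2) means a 2×2 filter with stride 2. These readings of the arguments are assumptions, so check the apidocs for the actual parameter order and padding behavior. LayerShapes is a hypothetical name.

```java
// Hypothetical helper: spatial size of a convolution or pooling output.
class LayerShapes {

    // Standard formula; integer division drops any remainder.
    static int outputSize(int input, int filter, int stride, int padding) {
        int span = input - filter + 2 * padding;
        if (span < 0 || stride <= 0) {
            throw new IllegalArgumentException("Filter larger than padded input, or invalid stride");
        }
        return span / stride + 1;
    }

    public static void main(String[] args) {
        int imageSize = 28;                        // MNIST images are 28x28
        int channels = 12;                         // assumed: 12 feature maps
        int conv = outputSize(imageSize, 5, 1, 0); // assumed: 5x5 filter, stride 1, no padding
        int pooled = outputSize(conv, 2, 2, 0);    // assumed: 2x2 max pooling, stride 2
        System.out.println("after convolution: " + conv + "x" + conv + "x" + channels);
        System.out.println("after pooling:     " + pooled + "x" + pooled + "x" + channels);
        System.out.println("inputs to the fully connected layer: " + pooled * pooled * channels);
    }
}
```

Under these assumptions the 28×28 image becomes 24×24×12 after convolution and 12×12×12 (1,728 values) after pooling, which the 60-neuron fully connected layer then maps to the 10 outputs.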
Train the convolutional neural network
The code segment below sets the basic training parameters and starts the training procedure. Details about the various settings of the backpropagation algorithm used to train convolutional neural networks are available in the apidocs.
BackpropagationTrainer trainer = neuralNet.getTrainer();
trainer.setLearningRate(0.001f)  // step size used when updating the internal parameters (weights)
       .setStopError(0.05f)      // stop training when the error drops to this threshold
       .setStopEpochs(1000);     // stop training after at most 1000 epochs
trainer.train(trainingSet);      // run the training on the specified training set
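The roles of the three settings can be illustrated with a toy gradient-descent loop in plain Java. This is a simplified stand-in, not Deep Netts' BackpropagationTrainer: it tunes a single weight in y = w · x, scales each update by the learning rate, and stops as soon as the error reaches the stop error or the epoch limit is hit. ToyTrainer and its methods are hypothetical.

```java
// Toy gradient descent showing the roles of learningRate, stopError and stopEpochs.
class ToyTrainer {

    // Returns {finalWeight, finalError, epochsRun}.
    static double[] train(double[] xs, double[] ys, double learningRate,
                          double stopError, int stopEpochs) {
        double w = 0.0;                 // the internal parameter being tuned
        double error = mse(xs, ys, w);
        int epoch = 0;
        while (error > stopError && epoch < stopEpochs) {
            // gradient of the mean squared error with respect to w
            double grad = 0.0;
            for (int i = 0; i < xs.length; i++) {
                grad += 2.0 * (w * xs[i] - ys[i]) * xs[i];
            }
            grad /= xs.length;
            w -= learningRate * grad;   // step against the gradient, scaled by the learning rate
            error = mse(xs, ys, w);
            epoch++;                    // one epoch = one pass over all examples
        }
        return new double[] {w, error, epoch};
    }

    // Mean squared error of the prediction w * x over all examples.
    static double mse(double[] xs, double[] ys, double w) {
        double sum = 0.0;
        for (int i = 0; i < xs.length; i++) {
            double diff = w * xs[i] - ys[i];
            sum += diff * diff;
        }
        return sum / xs.length;
    }

    public static void main(String[] args) {
        double[] xs = {1, 2, 3, 4};
        double[] ys = {3, 6, 9, 12};    // generated with w = 3
        double[] r = train(xs, ys, 0.01, 0.05, 1000);
        System.out.printf("w=%.4f error=%.5f epochs=%d%n", r[0], r[1], (int) r[2]);
    }
}
```

A larger learning rate takes bigger steps (faster, but it can overshoot); the stop error ends training once the result is good enough, and the epoch limit guarantees training ends even if that error is never reached.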
Once training starts, it logs information about every training epoch (one full pass through the training set), as shown below. This information includes the prediction error and accuracy during the training process.
Downloading and/or unpacking MNIST training set to: D:\DeepNettsProjects\GetStartedWithDeepLearningInJava\mnist - this may take a while ( 44.9 MB )!
Downloaded MNIST data set to mnist
Loading images...
Loaded 10 labels
Loaded 1000 images
Splitting data set: [0.65, 0.35]
Creating neural network architecture...
Training the neural network
------------------------------------------------------------------------
TRAINING NEURAL NETWORK
------------------------------------------------------------------------
Initial Train Error:2.445505
Epoch:1, Time:1893ms, TrainError:1.8640603, TrainErrorChange:-0.5814446, TrainAccuracy: 0.09453033
Epoch:2, Time:1718ms, TrainError:1.1450188, TrainErrorChange:-0.71904147, TrainAccuracy: 0.42761928
Epoch:3, Time:1728ms, TrainError:0.81875384, TrainErrorChange:-0.32626498, TrainAccuracy: 0.6140676
Epoch:4, Time:1820ms, TrainError:0.64733726, TrainErrorChange:-0.17141658, TrainAccuracy: 0.69346887
Epoch:5, Time:1962ms, TrainError:0.5406312, TrainErrorChange:-0.10670608, TrainAccuracy: 0.7378232
...
...
Epoch:42, Time:1379ms, TrainError:0.05284841, TrainErrorChange:-0.0019328743, TrainAccuracy: 0.99852943
Epoch:43, Time:1432ms, TrainError:0.050998032, TrainErrorChange:-0.0018503778, TrainAccuracy: 0.99852943
Epoch:44, Time:1423ms, TrainError:0.049276747, TrainErrorChange:-0.0017212853, TrainAccuracy: 1.0
TRAINING COMPLETED
Total Training Time: 131212ms
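The graph below shows how the output error decreases while the prediction accuracy increases during training. In the log above, training stopped after epoch 44, when the training error (0.049) fell below the stop error of 0.05.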
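The log line "Splitting data set: [0.65, 0.35]" indicates that the loaded images are divided into a training part (65%) and a test part (35%), so the model can later be evaluated on examples it did not train on. A plain-Java sketch of such a split, assuming a random shuffle with a fixed seed (DataSplit and split are hypothetical names, and Deep Netts' own splitting may differ in detail):

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical helper: shuffle examples and split them into training and test sets.
class DataSplit {

    static <T> List<List<T>> split(List<T> examples, double trainFraction, long seed) {
        List<T> shuffled = new ArrayList<>(examples);
        Collections.shuffle(shuffled, new Random(seed)); // avoid ordering bias, e.g. files sorted by digit
        int trainSize = (int) Math.round(shuffled.size() * trainFraction);
        List<List<T>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(shuffled.subList(0, trainSize)));               // training set
        parts.add(new ArrayList<>(shuffled.subList(trainSize, shuffled.size()))); // test set
        return parts;
    }

    public static void main(String[] args) {
        List<Integer> ids = new ArrayList<>();
        for (int i = 0; i < 1000; i++) ids.add(i);
        List<List<Integer>> parts = split(ids, 0.65, 123L);
        System.out.println("train: " + parts.get(0).size() + ", test: " + parts.get(1).size());
    }
}
```

Shuffling before splitting matters here because the images are grouped by digit folder; without it, whole digits could end up only in the test set.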
Testing the trained model
Model testing (or evaluation) checks how well the trained model performs on new data: examples it has not seen during training. In Deep Netts, testing is performed with a single call to the test() method, which returns various classification metrics that help you understand the quality of the predictions.
// Test/evaluate the trained network to see how it performs on unseen data (the test set)
EvaluationMetrics em = neuralNet.test(testSet);
The testing/evaluation returns the following results:
Classification metrics
Class: Macro Average
Total items: 368
True positive:298.0 Number of examples correctly classified as positive
True negative:0.0 Number of examples correctly classified as negative
False positive:35.0 Number of examples incorrectly classified as positive
False negative:35.0 Number of examples incorrectly classified as negative
Accuracy (ACC): 0.8097826 How often is classifier correct in total (percent of correct classifications)
Precision (PPV): 0.8948949 How often is classifier correct when it gives positive prediction
Recall: 0.8948949 When it is actually positive class, how often does it give positive prediction
F1 Score: 0.8948949 Harmonic average (balance) of precision and recall
False discovery rate (FDR): 0.1051051
Matthews correlation Coefficient (MCC): -0.10510510549197885
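Metrics like these can be computed from a confusion matrix. The sketch below, in plain Java, computes accuracy, per-class precision, recall and F1, and a macro average (the unweighted mean of a per-class metric over all classes). It is an independent illustration, not Deep Netts' implementation, so it is not meant to reproduce the report above; ClassificationMetrics and its methods are hypothetical.

```java
// Hypothetical helper: metrics from a confusion matrix where
// cm[actual][predicted] counts examples of class "actual" predicted as "predicted".
class ClassificationMetrics {

    static double accuracy(int[][] cm) {
        long correct = 0, total = 0;
        for (int a = 0; a < cm.length; a++) {
            for (int p = 0; p < cm.length; p++) {
                total += cm[a][p];
                if (a == p) correct += cm[a][p]; // diagonal = correct predictions
            }
        }
        return (double) correct / total;
    }

    // Of everything predicted as cls, how much really was cls (column sum in the denominator).
    static double precision(int[][] cm, int cls) {
        long tp = cm[cls][cls], predictedPositive = 0;
        for (int a = 0; a < cm.length; a++) predictedPositive += cm[a][cls];
        return predictedPositive == 0 ? 0.0 : (double) tp / predictedPositive;
    }

    // Of everything that really was cls, how much was predicted as cls (row sum in the denominator).
    static double recall(int[][] cm, int cls) {
        long tp = cm[cls][cls], actualPositive = 0;
        for (int p = 0; p < cm.length; p++) actualPositive += cm[cls][p];
        return actualPositive == 0 ? 0.0 : (double) tp / actualPositive;
    }

    // Harmonic mean of precision and recall.
    static double f1(double precision, double recall) {
        return precision + recall == 0 ? 0.0 : 2 * precision * recall / (precision + recall);
    }

    // Macro average: compute F1 per class, then take the unweighted mean.
    static double macroF1(int[][] cm) {
        double sum = 0.0;
        for (int c = 0; c < cm.length; c++) sum += f1(precision(cm, c), recall(cm, c));
        return sum / cm.length;
    }

    public static void main(String[] args) {
        int[][] cm = {{8, 2}, {1, 9}}; // toy 2-class example
        System.out.printf("accuracy=%.3f macroF1=%.3f%n", accuracy(cm), macroF1(cm));
    }
}
```

For the ten-digit model the matrix would be 10×10, with one row and one column per digit.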
The tricky part is understanding the various evaluation metrics; Deep Netts helps by printing a short explanation next to most of them.
Using the trained deep learning model in Java
The code below shows how to use the trained model for prediction. First, we load the image file into an ExampleImage, which provides the image data as a Tensor that can be used as input to the neural network's predict method.
The predict method returns a tensor (a multidimensional array) containing the probabilities for all digits. The digit with the highest probability is most likely the one in the input image.
ExampleImage someImage = new ExampleImage(ImageIO.read(new File("mnist/training/9/00019.png"))); // load an image from file
someImage.invert(); // invert colors (needed for this data set) so the network focuses on the dark digit strokes, not the white background
Tensor predictions = neuralNet.predict(someImage.getInput()); // get predictions for the image
int maxIdx = indexOfMax(predictions); // indexOfMax: helper (not shown) returning the index of the highest probability
LOGGER.info(predictions);
LOGGER.info("Image label with highest probability: " + neuralNet.getOutputLabel(maxIdx));
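The snippet above calls indexOfMax, which the original code does not show. A minimal sketch written against a plain float[], assuming the prediction values can be read from the Tensor as an array (for example via an accessor such as getValues(); check the apidocs for the exact method). The Predictions class is hypothetical.

```java
// Hypothetical implementation of the indexOfMax helper, working on a plain float[].
class Predictions {

    // Returns the index of the largest value (the most likely class); ties go to the first index.
    static int indexOfMax(float[] values) {
        if (values.length == 0) throw new IllegalArgumentException("empty array");
        int maxIdx = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[maxIdx]) maxIdx = i;
        }
        return maxIdx;
    }

    public static void main(String[] args) {
        float[] probs = {0.01f, 0.02f, 0.01f, 0.03f, 0.05f, 0.02f, 0.01f, 0.04f, 0.06f, 0.75f};
        System.out.println("index of highest probability: " + indexOfMax(probs)); // prints 9
    }
}
```

The returned index is then mapped to a label with getOutputLabel, as in the snippet above, rather than being assumed to equal the digit itself.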
Note that the trained model can also be saved, so you can load it and use it later in your application.
Building a Java Deep Learning Model Using Visual AI Builder Tool
Deep Netts provides Visual AI Builder, a tool that simplifies and accelerates building, training and debugging deep learning models using intuitive wizards and visual tools.
Required Resources
To run the examples from this tutorial, you need the following:
- Deep Netts library & Visual AI Builder tool download
- A step-by-step installation guide, available in the Quick Start tutorial
- Full source code available on GitHub: https://github.com/deepnetts/How-to-Get-Started-With-Deep-Learning-in-Java