A model zoo for deeplearning4j that provides access to pre-trained neural network models
Utilities for working with ImageNet classifications, including label decoding and prediction interpretation for models trained on the ImageNet dataset.
ImageNetLabels is a utility class for decoding ImageNet predictions into human-readable labels. It downloads the ImageNet class labels from a remote JSON file, caches them, and provides methods for interpreting model predictions.
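The label file is described only as a remote JSON file. Assume it follows the Keras-style class-index layout: an object that maps each index string to a [WordNet ID, label] pair, such as "0": ["n01440764", "tench"]. Under that assumption, building the index-ordered list that getLabel(n) reads from can be sketched as below. ClassIndexSketch and toLabelList are hypothetical names, and JSON parsing is assumed to have been done beforehand by a library such as Jackson.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical helper, not the deeplearning4j-zoo implementation.
// Assumes the class-index JSON is already parsed into a map of
// "index" -> [wordNetId, humanReadableLabel].
final class ClassIndexSketch {

    static List<String> toLabelList(Map<String, List<String>> classIndex) {
        List<String> labels = new ArrayList<>(classIndex.size());
        // Walk indices 0..n-1 in order so that list position == class index
        for (int i = 0; i < classIndex.size(); i++) {
            List<String> entry = classIndex.get(String.valueOf(i));
            if (entry == null || entry.size() < 2) {
                throw new IllegalArgumentException("Missing or malformed entry for class " + i);
            }
            labels.add(entry.get(1)); // element 1 holds the readable label
        }
        return labels;
    }
}
```

Because each label sits at the list position equal to its class index, every lookup takes constant time.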
/**
* Helper class with methods for returning ImageNet label descriptions and
* decoding prediction arrays to human-readable format.
*/
class ImageNetLabels {
/**
* Creates an ImageNetLabels instance and loads the label data.
* Downloads the class labels from the remote JSON file if they are not already cached.
*/
ImageNetLabels();
/**
* Returns the description of the nth class in the 1000 ImageNet classes
* @param n Class index (0-999)
* @return String description of the ImageNet class
*/
String getLabel(int n);
/**
* Decodes a prediction array into the top five matches with their probabilities.
* Given predictions from a trained model, returns a formatted string
* listing the top five classes and their respective probabilities.
* @param predictions INDArray of model predictions, shape [batchSize, 1000]
* @return Formatted string with top 5 predictions and probabilities
*/
String decodePredictions(INDArray predictions);
}
Usage Examples:
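The following plain-Java sketch illustrates the top-five decoding described in the decodePredictions Javadoc. It selects the k highest scores in one row and formats each as a percentage followed by its label. TopKSketch is a hypothetical name. It uses float[] and String[] in place of INDArray and the downloaded labels, and it is not the library's implementation.

```java
import java.util.Locale;

// Hypothetical top-k decoder for a single row of class probabilities.
final class TopKSketch {

    // Returns one line per class, "<percent>%, <label>", highest probability first
    static String decodeTopK(float[] probs, String[] labels, int k) {
        if (probs.length != labels.length) {
            throw new IllegalArgumentException("probs and labels must have the same length");
        }
        boolean[] taken = new boolean[probs.length];
        StringBuilder sb = new StringBuilder();
        int limit = Math.min(k, probs.length);
        for (int rank = 0; rank < limit; rank++) {
            // Find the highest remaining score; ties keep the lower index
            int best = -1;
            for (int i = 0; i < probs.length; i++) {
                if (!taken[i] && (best < 0 || probs[i] > probs[best])) {
                    best = i;
                }
            }
            taken[best] = true;
            sb.append(String.format(Locale.ROOT, "%.3f%%, %s", probs[best] * 100, labels[best]))
              .append('\n');
        }
        return sb.toString();
    }
}
```

For k = 5 and 1000 classes, repeated linear selection costs about 5,000 comparisons per row. That is simpler than a full sort and cheap enough for this use.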
// Create ImageNet labels helper
ImageNetLabels imageNetLabels = new ImageNetLabels();
// Get specific class label
String label283 = imageNetLabels.getLabel(283); // "Persian cat"
String label285 = imageNetLabels.getLabel(285); // "Egyptian cat"
// Decode model predictions
VGG16 vgg16 = new VGG16(1000, 42, 1);
// The pretrained VGG16 network is a ComputationGraph
ComputationGraph model = (ComputationGraph) vgg16.initPretrained(PretrainedType.IMAGENET);
// Assuming you have preprocessed image data as INDArray input
INDArray predictions = model.outputSingle(imageInput);
// Decode to human-readable format
String topPredictions = imageNetLabels.decodePredictions(predictions);
// Output format (percentages are illustrative):
// Predictions for batch :
// 85.234%, Egyptian cat
// 12.456%, Persian cat
// 1.789%, tabby cat
// 0.345%, tiger cat
// 0.176%, lynx
Complete Image Classification Example:
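The last step of the complete example below picks each image's most probable class. It does this with an argmax along dimension 1, the class axis of a [batchSize, 1000] prediction array. The same operation on plain arrays is sketched here. ArgMaxSketch is hypothetical, and float[][] stands in for the INDArray.

```java
// Hypothetical argmax along dimension 1: one class index per row (image).
final class ArgMaxSketch {

    // Index of the largest score in each row; ties resolve to the lowest index
    static int[] argMaxPerRow(float[][] predictions) {
        int[] result = new int[predictions.length];
        for (int row = 0; row < predictions.length; row++) {
            float[] scores = predictions[row];
            if (scores.length == 0) {
                throw new IllegalArgumentException("row " + row + " is empty");
            }
            int best = 0;
            for (int i = 1; i < scores.length; i++) {
                if (scores[i] > scores[best]) {
                    best = i;
                }
            }
            result[row] = best;
        }
        return result;
    }
}
```

Feeding each resulting index to getLabel turns it into a class name.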
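import java.io.File;
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.zoo.PretrainedType;
import org.deeplearning4j.zoo.model.VGG16;
import org.deeplearning4j.zoo.util.imagenet.ImageNetLabels;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor;
import org.nd4j.linalg.factory.Nd4j;

public class ClassifyImage {
    public static void main(String[] args) throws Exception {
        // 1. Load pre-trained ImageNet model (the pretrained VGG16 is a ComputationGraph)
        VGG16 vgg16 = new VGG16(1000, 42, 1);
        ComputationGraph model = (ComputationGraph) vgg16.initPretrained(PretrainedType.IMAGENET);
        // 2. Create ImageNet labels decoder
        ImageNetLabels imageNetLabels = new ImageNetLabels();
        // 3. Load the image at 224x224 with 3 channels and apply VGG16 preprocessing
        //    (NativeImageLoader comes from the datavec-data-image module)
        NativeImageLoader loader = new NativeImageLoader(224, 224, 3);
        INDArray imageInput = loader.asMatrix(new File(args[0]));
        DataNormalization scaler = new VGG16ImagePreProcessor();
        scaler.transform(imageInput);
        // 4. Get model predictions
        INDArray predictions = model.outputSingle(imageInput);
        // 5. Decode predictions to readable format
        String results = imageNetLabels.decodePredictions(predictions);
        System.out.println(results);
        // 6. Get the label of the most probable class (argmax over the class dimension)
        int topClass = Nd4j.argMax(predictions, 1).getInt(0);
        String topClassName = imageNetLabels.getLabel(topClass);
        System.out.println("Top prediction: " + topClassName);
    }
}
Batch Processing Example: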
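For a batch, the documented output repeats a "Predictions for batch i :" header before each row's matches. The sketch below mimics that layout on plain arrays by sorting each row's class indices by descending score. BatchDecodeSketch is hypothetical, and float[][] stands in for a [batchSize, numClasses] INDArray.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.Locale;

// Hypothetical batch decoder imitating the documented output layout.
final class BatchDecodeSketch {

    static String decodeBatch(float[][] predictions, String[] labels, int k) {
        StringBuilder sb = new StringBuilder();
        boolean multi = predictions.length > 1;
        for (int b = 0; b < predictions.length; b++) {
            final float[] row = predictions[b];
            if (row.length != labels.length) {
                throw new IllegalArgumentException("row " + b + " has " + row.length
                        + " scores but there are " + labels.length + " labels");
            }
            // The row index appears in the header only when the batch has more than one row
            sb.append(multi ? "Predictions for batch " + b + " :" : "Predictions for batch :").append('\n');
            // Sort class indices by descending score; object sorts are stable, so ties keep index order
            Integer[] order = new Integer[row.length];
            for (int i = 0; i < row.length; i++) {
                order[i] = i;
            }
            Arrays.sort(order, Comparator.comparingDouble((Integer i) -> row[i]).reversed());
            for (int r = 0; r < Math.min(k, row.length); r++) {
                int c = order[r];
                sb.append(String.format(Locale.ROOT, "%.3f%%, %s", row[c] * 100, labels[c])).append('\n');
            }
        }
        return sb.toString();
    }
}
```

Sorting costs O(n log n) per row. When only five results out of 1000 are needed, a partial selection would work equally well.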
ImageNetLabels labels = new ImageNetLabels();
// Process multiple images in batch
// model: a pretrained ComputationGraph such as VGG16; batchInput holds preprocessed images
INDArray batchPredictions = model.outputSingle(batchInput); // Shape: [batchSize, 1000]
// Decode entire batch - shows results for each image in batch
String batchResults = labels.decodePredictions(batchPredictions);
System.out.println(batchResults);
// Output format for batch (percentages are illustrative):
// Predictions for batch 0 :
// 85.234%, Egyptian cat
// ...
// Predictions for batch 1 :
// 67.891%, golden retriever
// ...
Label Index Reference:
ImageNetLabels labels = new ImageNetLabels();
// Common ImageNet classes examples:
String label0 = labels.getLabel(0); // "tench"
String label1 = labels.getLabel(1); // "goldfish"
String label151 = labels.getLabel(151); // "Chihuahua"
String label285 = labels.getLabel(285); // "Egyptian cat"
String label945 = labels.getLabel(945); // "bell pepper"
// ImageNet has 1000 classes (indices 0-999)
for (int i = 0; i < 1000; i++) {
String className = labels.getLabel(i);
System.out.println("Class " + i + ": " + className);
}
Error Handling:
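An invalid index may surface as a generic exception. A thin wrapper that checks the range first gives a clearer error message. SafeLabelLookup is a hypothetical class, and a List<String> stands in for the loaded ImageNet labels.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical bounds-checked label lookup.
final class SafeLabelLookup {

    private final List<String> labels;

    SafeLabelLookup(List<String> labels) {
        // Defensive, read-only copy so callers cannot change the table later
        this.labels = Collections.unmodifiableList(new ArrayList<>(labels));
    }

    // Validates the index first so the error names the valid range
    String getLabel(int n) {
        if (n < 0 || n >= labels.size()) {
            throw new IllegalArgumentException(
                    "Class index " + n + " is out of range 0-" + (labels.size() - 1));
        }
        return labels.get(n);
    }
}
```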
try {
    // The constructor fetches the label file, so network or I/O failures can surface here
    ImageNetLabels labels = new ImageNetLabels();
    // Get label for a valid index
    String validLabel = labels.getLabel(500); // Returns a valid class name
    // Note: invalid indices may cause exceptions;
    // always keep the index within the 0-999 range for ImageNet
} catch (Exception e) {
    System.err.println("Error accessing ImageNet labels: " + e.getMessage());
}
Label Data Source:
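The class index is fetched from a remote URL and reused after the first load. The sketch below shows a generic download-once file cache. It is not deeplearning4j-zoo's own caching code, and both CachedDownload and the caller-chosen cache path are hypothetical.

```java
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;

// Hypothetical download-once cache for a remote file such as a class index.
final class CachedDownload {

    // Returns cacheFile, downloading it from source only if it does not exist yet
    static Path fetchOnce(URL source, Path cacheFile) throws IOException {
        if (Files.exists(cacheFile)) {
            return cacheFile; // cache hit: no network access
        }
        Path dir = cacheFile.toAbsolutePath().getParent();
        Files.createDirectories(dir);
        // Download into a temporary file in the same directory, then move it into place,
        // so a failed transfer does not leave a truncated file at the cache path
        Path tmp = Files.createTempFile(dir, "download", ".part");
        try (InputStream in = source.openStream()) {
            Files.copy(in, tmp, StandardCopyOption.REPLACE_EXISTING);
            Files.move(tmp, cacheFile, StandardCopyOption.REPLACE_EXISTING);
        } finally {
            Files.deleteIfExists(tmp); // no-op after a successful move
        }
        return cacheFile;
    }
}
```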
The ImageNetLabels class automatically downloads the ImageNet class index from:
http://blob.deeplearning4j.org/utils/imagenet_class_index.json
Integration with Zoo Models:
// Works with any ImageNet pre-trained model from the zoo
ImageNetLabels labels = new ImageNetLabels();
// VGG16 with ImageNet weights
VGG16 vgg16 = new VGG16(1000, 42, 1);
ComputationGraph vgg16Model = (ComputationGraph) vgg16.initPretrained(PretrainedType.IMAGENET);
// AlexNet (if it had ImageNet weights available); AlexNet is a MultiLayerNetwork
AlexNet alexNet = new AlexNet(1000, 42, 1);
if (alexNet.pretrainedAvailable(PretrainedType.IMAGENET)) {
    MultiLayerNetwork alexNetModel = (MultiLayerNetwork) alexNet.initPretrained(PretrainedType.IMAGENET);
    INDArray predictions = alexNetModel.output(input);
    String results = labels.decodePredictions(predictions);
}
// ResNet50 with ImageNet weights (a ComputationGraph)
ResNet50 resNet50 = new ResNet50(1000, 42, 1);
if (resNet50.pretrainedAvailable(PretrainedType.IMAGENET)) {
    ComputationGraph resNetModel = (ComputationGraph) resNet50.initPretrained(PretrainedType.IMAGENET);
    INDArray predictions = resNetModel.outputSingle(input);
    String results = labels.decodePredictions(predictions);
}
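The labels from getLabel and decodePredictions assume a 1000-way ImageNet output, so they are only meaningful for models whose predictions have shape [batchSize, 1000]. The sketch below shows a cheap guard to run before decoding. ImageNetOutputCheck is hypothetical. It takes the array returned by INDArray.shape(), which is long[] in recent ND4J releases (older releases return int[]).

```java
import java.util.Arrays;

// Hypothetical sanity check run before decoding predictions with ImageNet labels.
final class ImageNetOutputCheck {

    static final long IMAGENET_CLASSES = 1000;

    // Throws unless shape is [batchSize, 1000] with batchSize >= 1
    static void requireImageNetShape(long[] shape) {
        if (shape.length != 2 || shape[0] < 1 || shape[1] != IMAGENET_CLASSES) {
            throw new IllegalArgumentException(
                    "Expected predictions of shape [batchSize, 1000] but got " + Arrays.toString(shape));
        }
    }
}
```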