Keras model import functionality for DeepLearning4J
—
Note: This functionality is deprecated. For new projects, use the deeplearning4j-zoo module which provides a more comprehensive and maintained set of pre-trained models.
This module offers legacy support for popular pre-trained image classification models, specifically VGG16 variants with ImageNet weights.
The TrainedModels enum provides access to pre-trained models with automatic downloading and setup.
public enum TrainedModels {
VGG16, // VGG16 with ImageNet weights and classification head
VGG16NOTOP; // VGG16 with ImageNet weights, no classification head
// Get the complete model as ComputationGraph
public ComputationGraph getComputationGraph() throws IOException;
// Get appropriate preprocessor for the model
public DataSetPreProcessor getDataSetPreProcessor();
// Get ImageNet class labels
public ArrayList<String> getLabels();
// Get expected input shape
public int[] getInputShape();
}
Complete VGG16 model with ImageNet weights and classification head.
TrainedModels.VGG16
Specifications:
VGG16 model without the final classification layers, useful for feature extraction and transfer learning.
TrainedModels.VGG16NOTOP
Specifications:
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModels;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.dataset.api.preprocessor.DataSetPreProcessor;
// Load VGG16 model
ComputationGraph vgg16 = TrainedModels.VGG16.getComputationGraph();
// Get appropriate preprocessor
DataSetPreProcessor preprocessor = TrainedModels.VGG16.getDataSetPreProcessor();
// Load and preprocess image
BufferedImage image = ImageIO.read(new File("image.jpg"));
INDArray imageArray = convertImageToINDArray(image); // Custom conversion method
preprocessor.preProcess(new DataSet(imageArray, null));
// Make prediction
INDArray output = vgg16.outputSingle(imageArray);
// Get predicted class
int predictedClass = Nd4j.argMax(output, 1).getInt(0);
String predictedLabel = TrainedModels.VGG16.getLabels().get(predictedClass);
System.out.println("Predicted: " + predictedLabel);
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModels;
import org.deeplearning4j.nn.graph.ComputationGraph;
// Load VGG16 without top layers for feature extraction
ComputationGraph featureExtractor = TrainedModels.VGG16NOTOP.getComputationGraph();
// Extract features from image
INDArray imageArray = preprocessedImage; // Your preprocessed image
INDArray features = featureExtractor.outputSingle(imageArray);
// Features shape: [1, 512, 7, 7] (DL4J's [minibatch, channels, height, width] layout) - can be flattened for use in other models
INDArray flatFeatures = features.reshape(1, 7 * 7 * 512);
// Load feature extractor
ComputationGraph baseModel = TrainedModels.VGG16NOTOP.getComputationGraph();
// Create new model with custom classification head
ComputationGraphConfiguration.GraphBuilder confBuilder = new NeuralNetConfiguration.Builder()
.graphBuilder();
// Add VGG16 layers (frozen)
// ... copy layers from baseModel ...
// Add custom classification layers
confBuilder.addLayer("custom_dense", new DenseLayer.Builder()
.nIn(7 * 7 * 512)
.nOut(128)
.activation(Activation.RELU)
.build(), "vgg16_features");
confBuilder.addLayer("output", new OutputLayer.Builder()
.nIn(128)
.nOut(numCustomClasses)
.activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.build(), "custom_dense");
// Build and initialize custom model
ComputationGraph customModel = new ComputationGraph(confBuilder.build());
customModel.init();
Access to ImageNet class labels for interpreting model predictions.
public class ImageNetLabels {
// Get all 1000 ImageNet class labels
public static ArrayList<String> getLabels();
// Get specific class label by index (0-999)
public static String getLabel(int n);
}
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.Utils.ImageNetLabels;
// Get all labels
ArrayList<String> allLabels = ImageNetLabels.getLabels();
System.out.println("Total classes: " + allLabels.size());
// Get specific label
String label = ImageNetLabels.getLabel(281); // "tabby, tabby cat"
System.out.println("Class 281: " + label);
// Top-k predictions
INDArray predictions = model.outputSingle(input);
int[] topK = getTopKIndices(predictions, 5); // Custom method to get top-5
System.out.println("Top 5 predictions:");
for (int i = 0; i < topK.length; i++) {
double confidence = predictions.getDouble(topK[i]);
String className = ImageNetLabels.getLabel(topK[i]);
System.out.println((i+1) + ". " + className + " (" + confidence + ")");
}
The VGG16 models require specific preprocessing to match training conditions.
// VGG16 preprocessing (from nd4j-dataset-api)
public class VGG16ImagePreProcessor implements DataSetPreProcessor {
// Applies VGG16-specific preprocessing:
// - Converts RGB to BGR
// - Subtracts ImageNet mean values
// - Scales appropriately
}
import org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor;
import org.datavec.image.loader.NativeImageLoader;
// Image loading and preprocessing pipeline
NativeImageLoader loader = new NativeImageLoader(224, 224, 3);
VGG16ImagePreProcessor preprocessor = new VGG16ImagePreProcessor();
// Load image
BufferedImage image = ImageIO.read(new File("input.jpg"));
INDArray imageArray = loader.asMatrix(image);
// Apply preprocessing
DataSet ds = new DataSet(imageArray, null);
preprocessor.preProcess(ds);
INDArray processedImage = ds.getFeatures();
// Now ready for VGG16 inference
INDArray prediction = vgg16Model.outputSingle(processedImage);
The TrainedModels enum automatically handles:
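- Downloading the model files on first use
- Caching the downloaded files in the ~/.dl4j/trainedmodels/ directory

Cache directory layout:
~/.dl4j/trainedmodels/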
├── vgg16/
│   ├── vgg16.json              # Model architecture
│   └── vgg16_weights.h5        # Pre-trained weights
└── vgg16notop/
    ├── vgg16notop.json         # Model architecture
    └── vgg16notop_weights.h5   # Pre-trained weights

// Models are automatically downloaded on first use
// Cache location: System.getProperty("user.home") + "/.dl4j/trainedmodels/"
// To clear cache, delete the directory manually
File cacheDir = new File(System.getProperty("user.home"), ".dl4j/trainedmodels");
if (cacheDir.exists()) {
// Delete cache directory if needed
}
Instead of using the deprecated TrainedModels enum, use the deeplearning4j-zoo module:
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-zoo</artifactId>
<version>0.9.1</version>
</dependency>
// New approach with DL4J Zoo
import org.deeplearning4j.zoo.model.VGG16;
import org.deeplearning4j.zoo.PretrainedType;
// Load VGG16 from zoo
VGG16 vgg16 = VGG16.builder().build();
ComputationGraph model = (ComputationGraph) vgg16.initPretrained(PretrainedType.IMAGENET);
Install with Tessl CLI
npx tessl i tessl/maven-org-deeplearning4j--deeplearning4j-modelimport