Keras model import functionality for DeepLearning4J
npx @tessl/cli install tessl/maven-org-deeplearning4j--deeplearning4j-modelimport@0.9.0

DeepLearning4J Model Import is a Java library for importing pre-trained neural network models from Keras into DeepLearning4J's Java ecosystem. It supports both Sequential and Functional API Keras models, including model configurations and trained weights stored in HDF5 format.
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-modelimport</artifactId>
    <version>0.9.1</version>
</dependency>

import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

For exception handling, also import:
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;

import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
// Import a complete Keras Functional API model (HDF5 format)
ComputationGraph functionalModel = KerasModelImport.importKerasModelAndWeights("path/to/model.h5");
// Import a Keras Sequential model (HDF5 format)
MultiLayerNetwork sequentialModel = KerasModelImport.importKerasSequentialModelAndWeights("path/to/sequential.h5");
// Import from separate JSON configuration and HDF5 weights
ComputationGraph separateFiles = KerasModelImport.importKerasModelAndWeights(
"path/to/model.json",
"path/to/weights.h5"
);
// Import configuration only (no weights)
ComputationGraphConfiguration config = KerasModelImport.importKerasModelConfiguration("path/to/model.json");

DeepLearning4J Model Import is built around several key components:
- The KerasModelImport class provides convenient static methods for common import operations.
- KerasModel and KerasSequentialModel handle the different Keras model types, with builder-pattern support.
- The Hdf5Archive class handles reading model data and weights from HDF5 files.

Core functionality for importing complete Keras models, with weights, into DeepLearning4J format. Supports both Functional API and Sequential models.
// Import complete models
public static ComputationGraph importKerasModelAndWeights(String modelHdf5Filename)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public static ComputationGraph importKerasModelAndWeights(String modelHdf5Filename, boolean enforceTrainingConfig)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelHdf5Filename)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelHdf5Filename, boolean enforceTrainingConfig)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

Import model configurations without weights. This is useful for recreating a model architecture that will be trained separately.
public static ComputationGraphConfiguration importKerasModelConfiguration(String modelJsonFilename)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public static ComputationGraphConfiguration importKerasModelConfiguration(String modelJsonFilename, boolean enforceTrainingConfig)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public static MultiLayerConfiguration importKerasSequentialConfiguration(String modelJsonFilename)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public static MultiLayerConfiguration importKerasSequentialConfiguration(String modelJsonFilename, boolean enforceTrainingConfig)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

Import models whose configuration (JSON) and weights (HDF5) are stored in separate files, a common Keras workflow.
public static ComputationGraph importKerasModelAndWeights(String modelJsonFilename, String weightsHdf5Filename)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public static ComputationGraph importKerasModelAndWeights(String modelJsonFilename, String weightsHdf5Filename, boolean enforceTrainingConfig)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelJsonFilename, String weightsHdf5Filename)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelJsonFilename, String weightsHdf5Filename, boolean enforceTrainingConfig)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

Advanced model construction using the builder pattern, for fine-grained control over the import process.
public static class ModelBuilder {
public ModelBuilder modelJson(String modelJson);
public ModelBuilder modelJsonFilename(String modelJsonFilename) throws IOException;
public ModelBuilder modelJsonInputStream(InputStream modelJsonInputStream) throws IOException;
public ModelBuilder modelYaml(String modelYaml);
public ModelBuilder modelYamlFilename(String modelYamlFilename) throws IOException;
public ModelBuilder modelYamlInputStream(InputStream modelYamlInputStream) throws IOException;
public ModelBuilder modelHdf5Filename(String modelHdf5Filename)
throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException;
public ModelBuilder weightsHdf5Filename(String weightsHdf5Filename);
public ModelBuilder trainingJson(String trainingJson);
public ModelBuilder trainingJsonInputStream(InputStream trainingJsonInputStream) throws IOException;
public ModelBuilder enforceTrainingConfig(boolean enforceTrainingConfig);
public KerasModel buildModel()
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public KerasSequentialModel buildSequential()
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
}

Mapping of Keras layer types to DeepLearning4J layers, covering the most common neural network components.
// Core layer types supported:
// - Dense (fully connected)
// - Convolution1D/2D
// - LSTM
// - Dropout
// - Activation
// - BatchNormalization
// - Pooling (Max/Average)
// - Flatten
// - Embedding
// - Merge
// - Input
// And more...

Legacy support for popular pre-trained models. Use the deeplearning4j-zoo module for new projects.
public enum TrainedModels {
VGG16, VGG16NOTOP;
public ComputationGraph getComputationGraph() throws IOException;
public DataSetPreProcessor getDataSetPreProcessor();
public ArrayList<String> getLabels();
}

// Main model types from DeepLearning4J
public class ComputationGraph {
// Functional API models
}
public class MultiLayerNetwork {
// Sequential models
}
public class ComputationGraphConfiguration {
// Configuration for Functional API models
}
public class MultiLayerConfiguration {
// Configuration for Sequential models
}

public class KerasModel {
public ComputationGraph getComputationGraph()
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public ComputationGraph getComputationGraph(boolean importWeights)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public ComputationGraphConfiguration getComputationGraphConfiguration()
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
}
public class KerasSequentialModel extends KerasModel {
public MultiLayerNetwork getMultiLayerNetwork()
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public MultiLayerNetwork getMultiLayerNetwork(boolean importWeights)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public MultiLayerConfiguration getMultiLayerConfiguration()
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
}

public class InvalidKerasConfigurationException extends Exception {
public InvalidKerasConfigurationException(String message);
public InvalidKerasConfigurationException(String message, Throwable cause);
public InvalidKerasConfigurationException(Throwable cause);
}
public class UnsupportedKerasConfigurationException extends Exception {
public UnsupportedKerasConfigurationException(String message);
public UnsupportedKerasConfigurationException(String message, Throwable cause);
public UnsupportedKerasConfigurationException(Throwable cause);
}

public class Hdf5Archive {
public Hdf5Archive(String archiveFilename);
public INDArray readDataSet(String dataSetName, String groupName);
public String readAttributeAsString(String attributeName, String objectName);
public String readAttributeAsJson(String attributeName);
public List<String> getGroups();
public List<String> getGroups(String groupName);
public List<String> getDataSets(String groupName);
public boolean hasAttribute(String attributeName);
}