Keras model import functionality for DeepLearning4J
Functions for importing complete Keras models, both configuration and weights, into DeepLearning4J. Functional API models are imported as a ComputationGraph and Sequential models as a MultiLayerNetwork.
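Which method to call depends on the model's top-level Keras class. That class is recorded in the file's model_config attribute: Keras writes "Sequential" for Sequential models and "Model" for Functional API models (newer tf.keras versions write "Functional"). The sketch below maps that class name to the name of the matching KerasModelImport method. It rests on two assumptions. First, the model_config JSON string has already been read from the HDF5 file, which needs an HDF5 library and is not shown. Second, the top-level class_name is the first one in the string, as in Keras's own serialization. KerasModelKind is a hypothetical helper and is not part of DeepLearning4J.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical helper (not part of DeepLearning4J): suggests which
// KerasModelImport method fits a model_config JSON string.
// Assumes the top-level "class_name" is the first one in the JSON.
final class KerasModelKind {
    private static final Pattern CLASS_NAME =
            Pattern.compile("\"class_name\"\\s*:\\s*\"([^\"]+)\"");

    private KerasModelKind() {}

    static String suggestImportMethod(String modelConfigJson) {
        Matcher m = CLASS_NAME.matcher(modelConfigJson);
        if (!m.find()) {
            throw new IllegalArgumentException("No class_name found in model_config");
        }
        switch (m.group(1)) {
            case "Sequential":
                return "importKerasSequentialModelAndWeights"; // -> MultiLayerNetwork
            case "Model":       // Functional API model (older Keras)
            case "Functional":  // Functional API model (newer tf.keras)
                return "importKerasModelAndWeights";           // -> ComputationGraph
            default:
                throw new IllegalArgumentException("Unrecognized top-level class_name: " + m.group(1));
        }
    }
}
```

The sketch uses a regular expression only to stay dependency-free. A real implementation would parse the JSON properly.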
Import complete Keras Functional API models stored in HDF5 format.
public static ComputationGraph importKerasModelAndWeights(String modelHdf5Filename)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public static ComputationGraph importKerasModelAndWeights(String modelHdf5Filename, boolean enforceTrainingConfig)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

Parameters:
modelHdf5Filename (String): Path to an HDF5 file containing both the model configuration and the weights.
enforceTrainingConfig (boolean): Whether to enforce training-related configurations. When true, unsupported configurations throw exceptions. When false, they generate warnings and import continues.

Returns:
ComputationGraph: DeepLearning4J computation graph with the imported weights.

Throws:
IOException: File I/O error while reading the HDF5 file.
InvalidKerasConfigurationException: Malformed or invalid Keras model configuration.
UnsupportedKerasConfigurationException: Keras feature not supported by DeepLearning4J.

Example:
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// Import with the default training-config enforcement
ComputationGraph model = KerasModelImport.importKerasModelAndWeights("my_keras_model.h5");
// Import with relaxed enforcement: unsupported training settings only produce warnings
ComputationGraph relaxedModel = KerasModelImport.importKerasModelAndWeights("my_keras_model.h5", false);
// Use the imported model; the input shape must match the model's input layer
INDArray input = Nd4j.randn(1, 224, 224, 3); // Example input
INDArray output = model.outputSingle(input);

Import complete Keras Sequential models stored in HDF5 format.
public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelHdf5Filename)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelHdf5Filename, boolean enforceTrainingConfig)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

Parameters:
modelHdf5Filename (String): Path to an HDF5 file containing the Sequential model configuration and weights.
enforceTrainingConfig (boolean): Whether to enforce training-related configurations, with the same behavior as for importKerasModelAndWeights.

Returns:
MultiLayerNetwork: DeepLearning4J multi-layer network with the imported weights.

Throws:
IOException: File I/O error while reading the HDF5 file.
InvalidKerasConfigurationException: Malformed or invalid Keras model configuration.
UnsupportedKerasConfigurationException: Keras feature not supported by DeepLearning4J.

Example:
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// Import a Sequential model
MultiLayerNetwork model = KerasModelImport.importKerasSequentialModelAndWeights("sequential_model.h5");
// Use the imported model
INDArray input = Nd4j.randn(1, 784); // Example input: one flattened 28x28 MNIST image
INDArray output = model.output(input);

Overloads that import from an InputStream are declared but currently throw UnsupportedOperationException.
public static ComputationGraph importKerasModelAndWeights(InputStream modelHdf5Stream)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public static ComputationGraph importKerasModelAndWeights(InputStream modelHdf5Stream, boolean enforceTrainingConfig)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public static MultiLayerNetwork importKerasSequentialModelAndWeights(InputStream modelHdf5Stream)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public static MultiLayerNetwork importKerasSequentialModelAndWeights(InputStream modelHdf5Stream, boolean enforceTrainingConfig)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;These methods currently throw UnsupportedOperationException with the message "Reading HDF5 files from InputStreams currently unsupported."
The HDF5 file must contain:
model_config attribute: the model architecture, stored as a JSON string
model_weights group: the layer weights
training_config attribute: the training configuration

Compatibility: this library is designed for Keras models from the TensorFlow/Keras ecosystem, particularly Keras versions that were current around the DeepLearning4J 0.9.1 release.
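A cheap pre-check before import can separate "this is not an HDF5 file at all" from a genuine configuration problem. HDF5 files begin with the 8-byte format signature 0x89 'H' 'D' 'F' '\r' '\n' 0x1A '\n'. The HDF5 specification also allows the signature at offsets 512, 1024, 2048 and so on when a user block is present. This sketch checks offset 0 only, on the assumption that the file was written with no user block (the h5py default). Hdf5Signature is a hypothetical helper. It checks only the signature, not whether model_config, model_weights or training_config are present.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;

// Hypothetical pre-check: does the file begin with the HDF5 signature?
// Assumes the signature is at offset 0 (no HDF5 user block).
final class Hdf5Signature {
    private static final byte[] SIGNATURE = {
            (byte) 0x89, 0x48, 0x44, 0x46, 0x0D, 0x0A, 0x1A, 0x0A // \211 H D F \r \n \032 \n
    };

    private Hdf5Signature() {}

    static boolean startsWithHdf5Signature(Path file) throws IOException {
        byte[] head = new byte[SIGNATURE.length];
        try (InputStream in = Files.newInputStream(file)) {
            int n = 0;
            // read() may return fewer bytes than requested, so loop until full or EOF
            while (n < head.length) {
                int r = in.read(head, n, head.length - n);
                if (r < 0) {
                    return false; // file is shorter than the signature
                }
                n += r;
            }
        }
        return Arrays.equals(head, SIGNATURE);
    }
}
```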
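Importing a model can fail with three checked exceptions. Each snippet below handles only the exception it illustrates, so the enclosing method must still handle or declare the other two.

IOException: the file is missing or cannot be read.
try {
    ComputationGraph model = KerasModelImport.importKerasModelAndWeights("nonexistent.h5");
} catch (IOException e) {
    System.err.println("Could not read model file: " + e.getMessage());
}

InvalidKerasConfigurationException: the stored Keras configuration is malformed.
try {
    ComputationGraph model = KerasModelImport.importKerasModelAndWeights("invalid_model.h5");
} catch (InvalidKerasConfigurationException e) {
    System.err.println("Invalid Keras configuration: " + e.getMessage());
}

UnsupportedKerasConfigurationException: the model uses a Keras feature that DeepLearning4J does not support.
try {
    ComputationGraph model = KerasModelImport.importKerasModelAndWeights("unsupported_model.h5", true);
} catch (UnsupportedKerasConfigurationException e) {
    System.err.println("Unsupported Keras feature: " + e.getMessage());
    // Retry with relaxed enforcement. This only relaxes training-related settings,
    // so an unsupported layer type will still fail.
    ComputationGraph model = KerasModelImport.importKerasModelAndWeights("unsupported_model.h5", false);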
}
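The last example above retries with enforceTrainingConfig set to false after a strict import fails. That strict-then-relaxed pattern can be factored into a small generic helper.

In this sketch, ModelLoader is a hypothetical functional interface. It stands in for a call such as `(path, enforce) -> KerasModelImport.importKerasModelAndWeights(path, enforce)`. The exception type that triggers the retry is passed in, so in real code it would be `UnsupportedKerasConfigurationException.class`. The sketch itself uses only the JDK.

```java
// Hypothetical stand-in for a loader such as
// (path, enforce) -> KerasModelImport.importKerasModelAndWeights(path, enforce)
@FunctionalInterface
interface ModelLoader<M> {
    M load(String path, boolean enforceTrainingConfig) throws Exception;
}

final class StrictThenRelaxed {
    private StrictThenRelaxed() {}

    // Tries a strict import first. If it fails with an exception of type retryOn,
    // retries once with enforcement off; any other exception propagates unchanged.
    static <M> M load(ModelLoader<M> loader, String path,
                      Class<? extends Exception> retryOn) throws Exception {
        try {
            return loader.load(path, true);
        } catch (Exception e) {
            if (!retryOn.isInstance(e)) {
                throw e;
            }
            System.err.println("Strict import failed (" + e.getMessage()
                    + "); retrying with enforceTrainingConfig=false");
            return loader.load(path, false);
        }
    }
}
```

The helper retries only on the exception type it is given. An IOException from a missing file, for example, is not hidden behind a pointless second attempt.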