Keras model import functionality for DeepLearning4J
Import models where configuration (JSON) and weights (HDF5) are stored in separate files. This is a common workflow in Keras where models are saved using model.to_json() and model.save_weights().
Import Keras Functional API models from separate JSON configuration and HDF5 weights files.
public static ComputationGraph importKerasModelAndWeights(String modelJsonFilename, String weightsHdf5Filename)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public static ComputationGraph importKerasModelAndWeights(String modelJsonFilename, String weightsHdf5Filename, boolean enforceTrainingConfig)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

Parameters:
modelJsonFilename (String): Path to JSON file containing model configuration
weightsHdf5Filename (String): Path to HDF5 file containing model weights
enforceTrainingConfig (boolean): Whether to enforce training-related configurations

Returns:
ComputationGraph: DeepLearning4J computation graph with imported configuration and weights

Throws:
IOException: File I/O errors when reading JSON or HDF5 files
InvalidKerasConfigurationException: Malformed or invalid Keras model configuration
UnsupportedKerasConfigurationException: Keras features not supported by DeepLearning4J

import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// Import from separate files with default enforcement
ComputationGraph model = KerasModelImport.importKerasModelAndWeights(
    "model_architecture.json",
    "model_weights.h5"
);

// Import with relaxed training config enforcement
ComputationGraph relaxedModel = KerasModelImport.importKerasModelAndWeights(
    "model_architecture.json",
    "model_weights.h5",
    false
);

// Use the imported model
INDArray input = Nd4j.randn(1, 224, 224, 3);
INDArray output = model.outputSingle(input);

Import Keras Sequential models from separate JSON configuration and HDF5 weights files.
public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelJsonFilename, String weightsHdf5Filename)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelJsonFilename, String weightsHdf5Filename, boolean enforceTrainingConfig)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

Parameters:
modelJsonFilename (String): Path to JSON file containing Sequential model configuration
weightsHdf5Filename (String): Path to HDF5 file containing model weights
enforceTrainingConfig (boolean): Whether to enforce training-related configurations

Returns:
MultiLayerNetwork: DeepLearning4J multi-layer network with imported configuration and weights

Throws:
IOException: File I/O errors when reading JSON or HDF5 files
InvalidKerasConfigurationException: Malformed or invalid Keras model configuration
UnsupportedKerasConfigurationException: Keras features not supported by DeepLearning4J

import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// Import Sequential model from separate files
MultiLayerNetwork model = KerasModelImport.importKerasSequentialModelAndWeights(
    "sequential_architecture.json",
    "sequential_weights.h5"
);

// Use the imported model
INDArray input = Nd4j.randn(1, 784); // Example for MNIST
INDArray output = model.output(input);

This workflow corresponds to the following Python/Keras code:
import keras
import json
# Create or load your model
model = keras.models.Model(inputs=inputs, outputs=outputs)
# Save architecture to JSON
model_json = model.to_json()
with open('model_architecture.json', 'w') as json_file:
    json_file.write(model_json)
# Save weights to HDF5
model.save_weights('model_weights.h5')

import keras
import json
# Create or load your Sequential model
model = keras.models.Sequential([...])
# Save architecture to JSON
model_json = model.to_json()
with open('sequential_architecture.json', 'w') as json_file:
    json_file.write(model_json)
# Save weights to HDF5
model.save_weights('sequential_weights.h5')

The JSON file must contain a valid Keras model configuration:
{
  "class_name": "Model" | "Sequential",
  "config": {
    // Model-specific configuration
  }
}

The HDF5 file must contain the model weights.
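The two file requirements above can be checked cheaply before handing the files to the importer. The following is a minimal sketch using only the Java standard library. KerasFilePreflight, ModelType, detectModelType, and hasHdf5Signature are hypothetical names introduced here for illustration and are not part of DeepLearning4J. It assumes that the first "class_name" key in the JSON text is the top-level one (a real implementation would use a JSON parser) and that the HDF5 signature sits at byte offset 0 of the weights file.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** Hypothetical pre-flight helper; not part of the DeepLearning4J API. */
final class KerasFilePreflight {

    /** Which family of import methods the JSON configuration calls for. */
    enum ModelType { FUNCTIONAL, SEQUENTIAL, UNKNOWN }

    // The 8-byte signature that opens an HDF5 file.
    private static final byte[] HDF5_SIGNATURE = {
        (byte) 0x89, 'H', 'D', 'F', '\r', '\n', 0x1A, '\n'
    };

    // Matches the first "class_name": "<value>" pair in the JSON text.
    // Assumption: that first pair is the top-level one.
    private static final Pattern CLASS_NAME =
        Pattern.compile("\"class_name\"\\s*:\\s*\"([^\"]+)\"");

    /** Classifies a Keras JSON configuration by its class_name value. */
    static ModelType detectModelType(String json) {
        Matcher m = CLASS_NAME.matcher(json);
        if (!m.find()) {
            return ModelType.UNKNOWN;
        }
        switch (m.group(1)) {
            case "Model":      return ModelType.FUNCTIONAL;
            case "Sequential": return ModelType.SEQUENTIAL;
            default:           return ModelType.UNKNOWN;
        }
    }

    /** Returns true if the file starts with the HDF5 signature bytes. */
    static boolean hasHdf5Signature(Path file) throws IOException {
        byte[] header = new byte[HDF5_SIGNATURE.length];
        try (InputStream in = Files.newInputStream(file)) {
            int read = 0;
            while (read < header.length) {
                int n = in.read(header, read, header.length - read);
                if (n < 0) {
                    return false; // file is shorter than the signature
                }
                read += n;
            }
        }
        return Arrays.equals(header, HDF5_SIGNATURE);
    }
}
```

A caller could use the detected type to choose between importKerasModelAndWeights (FUNCTIONAL) and importKerasSequentialModelAndWeights (SEQUENTIAL), and report UNKNOWN or a failed signature check before any import is attempted.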
// Step 1: Validate architecture
ComputationGraphConfiguration config = KerasModelImport.importKerasModelConfiguration("model.json");
// Step 2: Load full model with weights
ComputationGraph model = KerasModelImport.importKerasModelAndWeights("model.json", "weights.h5");
// Step 3: Deploy for inference
INDArray prediction = model.outputSingle(inputData);

// Load same architecture with different weight versions
String architecture = "resnet50_architecture.json";
// Load v1 weights
ComputationGraph modelV1 = KerasModelImport.importKerasModelAndWeights(architecture, "weights_v1.h5");
// Load v2 weights
ComputationGraph modelV2 = KerasModelImport.importKerasModelAndWeights(architecture, "weights_v2.h5");
// Compare performance or ensemble predictions

// Import pre-trained model architecture
ComputationGraphConfiguration baseConfig = KerasModelImport.importKerasModelConfiguration("base_model.json");
// Create model without pre-trained weights for custom training
ComputationGraph customModel = new ComputationGraph(baseConfig);
customModel.init();
// Or load pre-trained weights as starting point
ComputationGraph pretrainedModel = KerasModelImport.importKerasModelAndWeights("base_model.json", "pretrained_weights.h5");

import java.nio.file.Paths;
import java.nio.file.Files;

public static boolean validateFiles(String jsonPath, String weightsPath) {
    return Files.exists(Paths.get(jsonPath)) && Files.exists(Paths.get(weightsPath));
}
// Usage
if (validateFiles("model.json", "weights.h5")) {
    ComputationGraph model = KerasModelImport.importKerasModelAndWeights("model.json", "weights.h5");
} else {
    System.err.println("One or both files do not exist");
}

// Declare outside the try block so the model stays in scope afterwards
ComputationGraph model;
try {
    // Try with strict enforcement first
    model = KerasModelImport.importKerasModelAndWeights("model.json", "weights.h5", true);
} catch (UnsupportedKerasConfigurationException e) {
    System.out.println("Warning: " + e.getMessage());
    // Fall back to relaxed enforcement
    model = KerasModelImport.importKerasModelAndWeights("model.json", "weights.h5", false);
}
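The strict-then-relaxed pattern applies equally to Functional and Sequential imports, so it can be factored into one generic helper. The sketch below uses only the Java standard library. StrictThenRelaxed, Importer, and importWithFallback are hypothetical names introduced for illustration and are not part of DeepLearning4J. It assumes the caller wants exactly one retry, and only when the strict attempt fails with the given exception type.

```java
/** Hypothetical helper; not part of the DeepLearning4J API. */
final class StrictThenRelaxed {

    /** An import call parameterised by the enforceTrainingConfig flag. */
    @FunctionalInterface
    interface Importer<T> {
        T run(boolean enforceTrainingConfig) throws Exception;
    }

    /**
     * Runs the import with enforceTrainingConfig = true. If that throws an
     * exception of type fallbackOn, retries once with false. Any other
     * exception is rethrown without a retry.
     */
    static <T> T importWithFallback(Importer<T> importer,
                                    Class<? extends Exception> fallbackOn) throws Exception {
        try {
            return importer.run(true);
        } catch (Exception e) {
            if (!fallbackOn.isInstance(e)) {
                throw e; // unrelated failure (for example a missing file)
            }
            System.err.println("Strict import failed, retrying relaxed: " + e.getMessage());
            return importer.run(false);
        }
    }
}

// Usage sketch (requires DeepLearning4J on the classpath):
// ComputationGraph model = StrictThenRelaxed.importWithFallback(
//     enforce -> KerasModelImport.importKerasModelAndWeights("model.json", "weights.h5", enforce),
//     UnsupportedKerasConfigurationException.class);
```

Restricting the retry to one exception type keeps I/O errors and malformed configurations from being masked by the relaxed attempt.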