Keras model import functionality for DeepLearning4J
—
Advanced model construction using the builder pattern for fine-grained control over the import process. The ModelBuilder provides a fluent interface for configuring all aspects of model import.
The `KerasModel.ModelBuilder` class provides a fluent API for constructing `KerasModel` instances (Functional API, via `buildModel()`) and `KerasSequentialModel` instances (Sequential, via `buildSequential()`), with full control over configuration sources and import behavior.
/**
 * Fluent builder for importing a Keras model into DeepLearning4J.
 *
 * <p>Typical flow: supply the architecture (a JSON or YAML string, file, or stream, or a
 * complete HDF5 file), optionally weights and a training configuration, and choose whether
 * training-configuration compatibility is enforced. Then call {@link #buildModel()} for a
 * Functional API model or {@link #buildSequential()} for a Sequential model. Each setter
 * returns a {@code ModelBuilder}, so calls can be chained.
 */
public static class ModelBuilder implements Cloneable {
/** Sets the model architecture directly from a JSON string. */
public ModelBuilder modelJson(String modelJson);
/** Loads the model architecture from a JSON file. @throws IOException declared for file access */
public ModelBuilder modelJsonFilename(String modelJsonFilename) throws IOException;
/** Loads the model architecture from a JSON stream. @throws IOException declared for stream access */
public ModelBuilder modelJsonInputStream(InputStream modelJsonInputStream) throws IOException;
/** Sets the model architecture directly from a YAML string. */
public ModelBuilder modelYaml(String modelYaml);
/** Loads the model architecture from a YAML file. @throws IOException declared for file access */
public ModelBuilder modelYamlFilename(String modelYamlFilename) throws IOException;
/** Loads the model architecture from a YAML stream. @throws IOException declared for stream access */
public ModelBuilder modelYamlInputStream(InputStream modelYamlInputStream) throws IOException;
/**
 * Loads a complete model (configuration, weights and optional training configuration)
 * from a single HDF5 file.
 *
 * @throws UnsupportedKerasConfigurationException declared by this setter
 * @throws InvalidKerasConfigurationException declared by this setter
 */
public ModelBuilder modelHdf5Filename(String modelHdf5Filename)
throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException;
/** Sets the model weights from a separate HDF5 file. */
public ModelBuilder weightsHdf5Filename(String weightsHdf5Filename);
/** Sets the training configuration directly from a JSON string. */
public ModelBuilder trainingJson(String trainingJson);
/** Loads the training configuration from a JSON stream. @throws IOException declared for stream access */
public ModelBuilder trainingJsonInputStream(InputStream trainingJsonInputStream) throws IOException;
/** Sets whether training-configuration compatibility is enforced during import. */
public ModelBuilder enforceTrainingConfig(boolean enforceTrainingConfig);
/** Builds a Functional API model ({@code KerasModel}) from the configured sources. */
public KerasModel buildModel()
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
/** Builds a Sequential model ({@code KerasSequentialModel}) from the configured sources. */
public KerasSequentialModel buildSequential()
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
}Set the model architecture from various sources.
// Architecture setters. Each example in this document configures a single architecture
// source per builder (JSON, YAML, or a complete HDF5 file via modelHdf5Filename).
// Set model configuration directly as a JSON string
public ModelBuilder modelJson(String modelJson);
// Load model configuration from a JSON file (declares IOException)
public ModelBuilder modelJsonFilename(String modelJsonFilename) throws IOException;
// Load model configuration from a JSON InputStream (declares IOException)
public ModelBuilder modelJsonInputStream(InputStream modelJsonInputStream) throws IOException;
// Set model configuration directly as a YAML string
public ModelBuilder modelYaml(String modelYaml);
// Load model configuration from a YAML file (declares IOException)
public ModelBuilder modelYamlFilename(String modelYamlFilename) throws IOException;
// Load model configuration from a YAML InputStream (declares IOException)
public ModelBuilder modelYamlInputStream(InputStream modelYamlInputStream) throws IOException;Load both configuration and weights from a single HDF5 file.
// Load complete model from HDF5 file (configuration + weights + optional training config).
// Unlike the JSON/YAML setters, this one declares the Keras configuration exceptions
// rather than IOException.
public ModelBuilder modelHdf5Filename(String modelHdf5Filename)
throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException;Set model weights from HDF5 files.
// Set weights from a separate HDF5 file; pair this with a JSON or YAML architecture source
// (as the examples in this document do)
public ModelBuilder weightsHdf5Filename(String weightsHdf5Filename);Set training-related configuration for complete model import.
// Set training configuration directly as a JSON string (must be valid JSON, i.e. double-quoted)
public ModelBuilder trainingJson(String trainingJson);
// Load training configuration from a JSON InputStream (declares IOException)
public ModelBuilder trainingJsonInputStream(InputStream trainingJsonInputStream) throws IOException;Control how the import process handles unsupported configurations.
// Set whether to enforce training configuration compatibility.
// The error-handling example in this document passes true first and, on
// UnsupportedKerasConfigurationException, retries with false.
public ModelBuilder enforceTrainingConfig(boolean enforceTrainingConfig);Create the final model instances from the configured builder.
// Build Functional API model (KerasModel)
public KerasModel buildModel()
    throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
// Build Sequential model (KerasSequentialModel)
public KerasSequentialModel buildSequential()
    throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.graph.ComputationGraph;
// Create builder and configure
KerasModel kerasModel = new KerasModel.ModelBuilder()
    .modelJsonFilename("model_config.json")
    .weightsHdf5Filename("model_weights.h5")
    .enforceTrainingConfig(false)
    .buildModel();
// Get DeepLearning4J model
ComputationGraph model = kerasModel.getComputationGraph();

import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
// Build Sequential model
KerasSequentialModel kerasModel = new KerasModel.ModelBuilder()
    .modelJsonFilename("sequential_config.json")
    .weightsHdf5Filename("sequential_weights.h5")
    .buildSequential();
// Get DeepLearning4J model
MultiLayerNetwork model = kerasModel.getMultiLayerNetwork();

// Import complete model from single HDF5 file
KerasModel kerasModel = new KerasModel.ModelBuilder()
    .modelHdf5Filename("complete_model.h5")
    .enforceTrainingConfig(true)
    .buildModel();
ComputationGraph model = kerasModel.getComputationGraph();

// Configure model from JSON string directly.
// The string must be valid JSON: keys and string values use double quotes (escaped inside
// a Java literal). Single-quoted "JSON" is rejected by standard JSON parsers.
// The "config" value below is an abbreviated placeholder; substitute the full JSON
// produced by Keras (e.g. model.to_json()).
String modelJson = "{\"class_name\": \"Sequential\", \"config\": [...]}";
KerasSequentialModel kerasModel = new KerasModel.ModelBuilder()
    .modelJson(modelJson)
    .buildSequential();
MultiLayerNetwork model = kerasModel.getMultiLayerNetwork();

// Include training configuration (valid JSON, i.e. double-quoted).
// NOTE(review): the keys shown are illustrative; prefer the training configuration that
// Keras saved alongside the model so the importer sees the fields it expects.
KerasModel kerasModel = new KerasModel.ModelBuilder()
    .modelJsonFilename("model.json")
    .weightsHdf5Filename("weights.h5")
    .trainingJson("{\"loss\": \"categorical_crossentropy\", \"optimizer\": \"adam\"}")
    .enforceTrainingConfig(true)
    .buildModel();

import java.io.FileInputStream;
// Load from InputStreams (useful for resources or network streams)
try (FileInputStream modelStream = new FileInputStream("model.json");
     FileInputStream trainingStream = new FileInputStream("training.json")) {
    KerasModel kerasModel = new KerasModel.ModelBuilder()
        .modelJsonInputStream(modelStream)
        .trainingJsonInputStream(trainingStream)
        .weightsHdf5Filename("weights.h5")
        .buildModel();
    ComputationGraph model = kerasModel.getComputationGraph();
}

// Use YAML configuration instead of JSON
KerasModel kerasModel = new KerasModel.ModelBuilder()
    .modelYamlFilename("model_config.yaml")
    .weightsHdf5Filename("weights.h5")
    .buildModel();

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
// Build a configuration-only model first (no weights) for validation
KerasModel kerasModel = new KerasModel.ModelBuilder()
    .modelJsonFilename("model.json")
    .enforceTrainingConfig(true)
    .buildModel();
// Validate configuration
ComputationGraphConfiguration config = kerasModel.getComputationGraphConfiguration();
// getVertices() counts every graph vertex (layer vertices and any other vertex kinds),
// so report it as a vertex count rather than a layer count
System.out.println("Model has " + config.getVertices().size() + " vertices");
// Then load with weights
kerasModel = new KerasModel.ModelBuilder()
    .modelJsonFilename("model.json")
    .weightsHdf5Filename("weights.h5")
    .enforceTrainingConfig(false) // governs training-config compatibility checks, not weight loading
    .buildModel();

// Load architecture once, use with different weight files
KerasModel.ModelBuilder baseBuilder = new KerasModel.ModelBuilder()
    .modelJsonFilename("shared_architecture.json")
    .enforceTrainingConfig(false);
// NOTE(review): baseBuilder itself is reused here, so each weightsHdf5Filename(...) call
// presumably replaces the weights source set by the previous one. If buildModel() releases
// builder resources in your DL4J version, use a fresh (or cloned) builder per model.
// Model with weights v1
ComputationGraph modelV1 = baseBuilder
    .weightsHdf5Filename("weights_v1.h5")
    .buildModel()
    .getComputationGraph();
// Model with weights v2
ComputationGraph modelV2 = baseBuilder
    .weightsHdf5Filename("weights_v2.h5")
    .buildModel()
    .getComputationGraph();

// Error handling with a relaxed-enforcement fallback.
// model stays null if neither attempt succeeds.
ComputationGraph model = null;
try {
    KerasModel kerasModel = new KerasModel.ModelBuilder()
        .modelJsonFilename("model.json")
        .weightsHdf5Filename("weights.h5")
        .enforceTrainingConfig(true)
        .buildModel();
    model = kerasModel.getComputationGraph();
} catch (IOException e) {
    System.err.println("File I/O error: " + e.getMessage());
} catch (InvalidKerasConfigurationException e) {
    System.err.println("Invalid configuration: " + e.getMessage());
} catch (UnsupportedKerasConfigurationException e) {
    System.err.println("Unsupported feature: " + e.getMessage());
    // Retry with relaxed enforcement. buildModel() declares the same checked exceptions,
    // and an exception thrown inside a catch block is NOT handled by the sibling catch
    // clauses above, so the retry needs its own handler.
    try {
        KerasModel kerasModel = new KerasModel.ModelBuilder()
            .modelJsonFilename("model.json")
            .weightsHdf5Filename("weights.h5")
            .enforceTrainingConfig(false)
            .buildModel();
        model = kerasModel.getComputationGraph();
    } catch (IOException | InvalidKerasConfigurationException
            | UnsupportedKerasConfigurationException retryError) {
        System.err.println("Retry with relaxed enforcement failed: " + retryError.getMessage());
    }
}

import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import java.io.FileInputStream;
// Simple case - use static method
ComputationGraph simple = KerasModelImport.importKerasModelAndWeights("model.h5");
// Complex case - use builder (the training stream is opened and closed here)
try (FileInputStream trainingStream = new FileInputStream("training.json")) {
    KerasModel complex = new KerasModel.ModelBuilder()
        .modelYamlFilename("architecture.yaml")
        .weightsHdf5Filename("weights.h5")
        .trainingJsonInputStream(trainingStream)
        .enforceTrainingConfig(false)
        .buildModel();
}

The ModelBuilder implements Cloneable for creating copies with shared configuration:
// Create base configuration
KerasModel.ModelBuilder baseBuilder = new KerasModel.ModelBuilder()
.modelJsonFilename("base_architecture.json")
.enforceTrainingConfig(false);
// Clone and customize for different use cases.
// NOTE(review): this assumes ModelBuilder overrides clone() as a public method returning
// ModelBuilder. Object.clone() is protected and returns Object, so the chained call below
// would not compile otherwise. Object.clone() also copies fields shallowly, so any resources
// the builder holds may be shared between clones. Confirm against the DL4J version in use.
// These builders are only configured here; call buildModel() or buildSequential() on each
// to obtain a model.
KerasModel.ModelBuilder trainBuilder = baseBuilder.clone()
.weightsHdf5Filename("initial_weights.h5");
KerasModel.ModelBuilder inferenceBuilder = baseBuilder.clone()
.weightsHdf5Filename("trained_weights.h5");Install with Tessl CLI
npx tessl i tessl/maven-org-deeplearning4j--deeplearning4j-modelimport