Keras model import functionality for DeepLearning4J
Import model configurations without weights. This is useful for creating model architectures that can be trained from scratch or initialized with weights loaded separately.
Import configuration for Keras Functional API models from JSON files.
public static ComputationGraphConfiguration importKerasModelConfiguration(String modelJsonFilename)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public static ComputationGraphConfiguration importKerasModelConfiguration(String modelJsonFilename, boolean enforceTrainingConfig)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

Parameters:
modelJsonFilename (String): Path to JSON file containing Keras model configuration
enforceTrainingConfig (boolean): Whether to enforce training-related configurations
Returns:
ComputationGraphConfiguration: DeepLearning4J computation graph configuration
Throws:
IOException: File I/O errors when reading the JSON file
InvalidKerasConfigurationException: Malformed or invalid Keras model configuration
UnsupportedKerasConfigurationException: Keras features not supported by DeepLearning4J

Example:
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
// Import configuration only
ComputationGraphConfiguration config = KerasModelImport.importKerasModelConfiguration("model_config.json");
// Create model from configuration (without pre-trained weights)
ComputationGraph model = new ComputationGraph(config);
model.init();
// Now you can train the model or load weights separately

Import configuration for Keras Sequential models from JSON files.
public static MultiLayerConfiguration importKerasSequentialConfiguration(String modelJsonFilename)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
public static MultiLayerConfiguration importKerasSequentialConfiguration(String modelJsonFilename, boolean enforceTrainingConfig)
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

Parameters:
modelJsonFilename (String): Path to JSON file containing Keras Sequential model configuration
enforceTrainingConfig (boolean): Whether to enforce training-related configurations
Returns:
MultiLayerConfiguration: DeepLearning4J multi-layer network configuration
Throws:
IOException: File I/O errors when reading the JSON file
InvalidKerasConfigurationException: Malformed or invalid Keras model configuration
UnsupportedKerasConfigurationException: Keras features not supported by DeepLearning4J

Example:
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
// Import Sequential configuration only
MultiLayerConfiguration config = KerasModelImport.importKerasSequentialConfiguration("sequential_config.json");
// Create model from configuration
MultiLayerNetwork model = new MultiLayerNetwork(config);
model.init();
// Train or load weights as needed

Keras model configurations can be saved to JSON using Keras/TensorFlow:
# In Python with Keras
import keras

# Load or create your model
model = keras.models.load_model('my_model.h5')

# Save the architecture (no weights) to JSON. to_json() works for both
# Functional and Sequential models and records the top-level "class_name".
model_json = model.to_json()

# Use a separate file name for Sequential models
if isinstance(model, keras.models.Sequential):
    filename = 'sequential_config.json'
else:
    filename = 'model_config.json'

with open(filename, 'w') as f:
    f.write(model_json)

// Import architecture from pre-trained model
ComputationGraphConfiguration config = KerasModelImport.importKerasModelConfiguration("pretrained_config.json");
// Create new model with same architecture
ComputationGraph model = new ComputationGraph(config);
model.init();
// Train on your specific dataset
// ... training code ...

// Import and examine model configuration
ComputationGraphConfiguration config = KerasModelImport.importKerasModelConfiguration("model_config.json");
// Access configuration details
// Vertices include layers as well as non-layer vertices such as merges
System.out.println("Number of vertices: " + config.getVertices().size());
// getNetworkInputs()/getNetworkOutputs() return List<String>, printed directly
System.out.println("Input names: " + config.getNetworkInputs());
System.out.println("Output names: " + config.getNetworkOutputs());

// Import configuration
ComputationGraphConfiguration config = KerasModelImport.importKerasModelConfiguration("model_config.json");
// Create model
ComputationGraph model = new ComputationGraph(config);
model.init();
// Load custom weights (not from Keras)
// ... custom weight loading logic ...

The JSON configuration file should follow the Keras model serialization format.

Functional API model:
{
"class_name": "Model",
"config": {
"name": "model_name",
"layers": [...],
"input_layers": [...],
"output_layers": [...]
}
}

Sequential model:
{
"class_name": "Sequential",
"config": [
{
"class_name": "Dense",
"config": {...}
},
...
]
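}

When enforceTrainingConfig is set to true, training-related settings that DeepLearning4J cannot map (for example, an unsupported regularizer) cause an UnsupportedKerasConfigurationException.
When enforceTrainingConfig is set to false, such settings are skipped and a warning is logged instead.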
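As an optional pre-flight check before moving to the Java side, a short Python script can read a saved Keras JSON file, confirm it matches one of the two layouts shown in the JSON format section, and report which KerasModelImport method fits. This is a sketch under stated assumptions: the helper names suggest_import_method and suggest_import_method_for_file are hypothetical (not part of DeepLearning4J or Keras), only the top-level structure is checked (so a file that passes can still be rejected by the importer for unsupported layers), and accepting a Sequential config wrapped as {"layers": [...]} is an assumption about newer Keras 2 output.

```python
import json

# Hypothetical helper, not part of DeepLearning4J or Keras: maps the top-level
# "class_name" of a saved Keras JSON config to the KerasModelImport method
# that this page describes for that model type.
IMPORT_METHODS = {
    "Model": "importKerasModelConfiguration",
    "Sequential": "importKerasSequentialConfiguration",
}


def suggest_import_method(model_json: str) -> str:
    """Return the KerasModelImport method name suited to a Keras JSON config.

    Raises ValueError if the JSON does not match one of the two documented
    layouts (only the top-level structure is checked).
    """
    model = json.loads(model_json)
    class_name = model.get("class_name") if isinstance(model, dict) else None
    if class_name not in IMPORT_METHODS:
        raise ValueError(f"unrecognised top-level class_name: {class_name!r}")

    config = model.get("config")
    if class_name == "Model":
        # Functional API layout: a dict with the layer list and graph endpoints.
        required = ("layers", "input_layers", "output_layers")
        if not isinstance(config, dict) or any(key not in config for key in required):
            raise ValueError("Model config must contain " + ", ".join(required))
    else:
        # Sequential layout: a list of layer entries. Assumption: also accept
        # the {"layers": [...]} wrapper that newer Keras 2 releases write.
        layers = config.get("layers") if isinstance(config, dict) else config
        if not isinstance(layers, list) or not all(
            isinstance(layer, dict) and "class_name" in layer for layer in layers
        ):
            raise ValueError("Sequential config must be a list of layer entries")

    return IMPORT_METHODS[class_name]


def suggest_import_method_for_file(path: str) -> str:
    """Read a JSON file saved with model.to_json() and check it."""
    with open(path, "r", encoding="utf-8") as f:
        return suggest_import_method(f.read())


if __name__ == "__main__":
    sample = json.dumps({
        "class_name": "Sequential",
        "config": [{"class_name": "Dense", "config": {}}],
    })
    print(suggest_import_method(sample))  # importKerasSequentialConfiguration
```

Raising ValueError on anything unrecognised, rather than guessing, keeps the check conservative: an unexpected file is flagged for manual inspection instead of being routed to the wrong import method.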