# Model Import

Import complete Keras models, including their weights, into DeepLearning4J format. Both Functional API models and Sequential models are supported.

## Functional API Model Import

Import complete Keras Functional API models stored in HDF5 format.

```java { .api }
public static ComputationGraph importKerasModelAndWeights(String modelHdf5Filename)
    throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

public static ComputationGraph importKerasModelAndWeights(String modelHdf5Filename, boolean enforceTrainingConfig)
    throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
```

### Parameters

- `modelHdf5Filename` (String): Path to the HDF5 file containing both the model configuration and the weights
- `enforceTrainingConfig` (boolean): Whether to enforce training-related configurations. When `true`, unsupported configurations throw an exception. When `false`, the import generates a warning and continues.

### Returns

- `ComputationGraph`: DeepLearning4J computation graph with imported weights

### Exceptions

- `IOException`: File I/O errors when reading the HDF5 file
- `InvalidKerasConfigurationException`: Malformed or invalid Keras model configuration
- `UnsupportedKerasConfigurationException`: Keras features not supported by DeepLearning4J

### Usage Examples

```java
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// Import with default training config enforcement
ComputationGraph model = KerasModelImport.importKerasModelAndWeights("my_keras_model.h5");

// Import with relaxed training config enforcement
ComputationGraph relaxedModel = KerasModelImport.importKerasModelAndWeights("my_keras_model.h5", false);

// Use the imported model
// Example input; the shape must match what the imported model's input layer expects
INDArray input = Nd4j.randn(new int[] {1, 224, 224, 3});
INDArray output = model.outputSingle(input);
```

## Sequential Model Import

Import complete Keras Sequential models stored in HDF5 format.

```java { .api }
public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelHdf5Filename)
    throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelHdf5Filename, boolean enforceTrainingConfig)
    throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
```

### Parameters

- `modelHdf5Filename` (String): Path to the HDF5 file containing the Sequential model configuration and weights
- `enforceTrainingConfig` (boolean): Whether to enforce training-related configurations. When `true`, unsupported configurations throw an exception. When `false`, the import generates a warning and continues.

### Returns

- `MultiLayerNetwork`: DeepLearning4J multi-layer network with imported weights

### Exceptions

- `IOException`: File I/O errors when reading the HDF5 file
- `InvalidKerasConfigurationException`: Malformed or invalid Keras model configuration
- `UnsupportedKerasConfigurationException`: Keras features not supported by DeepLearning4J

### Usage Examples

```java
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// Import Sequential model
MultiLayerNetwork model = KerasModelImport.importKerasSequentialModelAndWeights("sequential_model.h5");

// Use the imported model
INDArray input = Nd4j.randn(1, 784); // Example input for MNIST (one flattened 28x28 image)
INDArray output = model.output(input);
```

## InputStream Import (Unsupported)

The `InputStream` overloads are declared but not implemented. Calling any of them currently throws `UnsupportedOperationException`.

```java { .api }
public static ComputationGraph importKerasModelAndWeights(InputStream modelHdf5Stream)
    throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

public static ComputationGraph importKerasModelAndWeights(InputStream modelHdf5Stream, boolean enforceTrainingConfig)
    throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

public static MultiLayerNetwork importKerasSequentialModelAndWeights(InputStream modelHdf5Stream)
    throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

public static MultiLayerNetwork importKerasSequentialModelAndWeights(InputStream modelHdf5Stream, boolean enforceTrainingConfig)
    throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
```

The exception carries the message "Reading HDF5 files from InputStreams currently unsupported."
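Because only the filename-based overloads work, one possible workaround is to spool the stream to a temporary file and pass that file's path to the import method. The sketch below shows only the spooling step using the Java standard library. The class name `Hdf5StreamSpooler` and the method `spoolToTempFile` are hypothetical helpers, not part of DeepLearning4J.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;

// Hypothetical helper (not part of DeepLearning4J): writes a stream to a temp file.
final class Hdf5StreamSpooler {

    /** Copies the whole stream into a temporary .h5 file and returns its path. */
    static Path spoolToTempFile(InputStream modelHdf5Stream) throws IOException {
        Path tempFile = Files.createTempFile("keras-model-", ".h5");
        tempFile.toFile().deleteOnExit(); // best-effort cleanup when the JVM exits
        // createTempFile already created the file, so it must be overwritten
        Files.copy(modelHdf5Stream, tempFile, StandardCopyOption.REPLACE_EXISTING);
        return tempFile;
    }

    public static void main(String[] args) throws IOException {
        // Placeholder bytes stand in for a real HDF5 stream (e.g. a classpath resource)
        try (InputStream in = new ByteArrayInputStream(new byte[] {1, 2, 3})) {
            Path modelPath = spoolToTempFile(in);
            System.out.println("Spooled " + Files.size(modelPath) + " bytes to " + modelPath);
            // With DeepLearning4J on the classpath, the path can then be imported:
            // ComputationGraph model =
            //     KerasModelImport.importKerasModelAndWeights(modelPath.toString());
        }
    }
}
```

The temporary file costs extra disk I/O, but it keeps the import on the supported filename-based code path.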
108
109
## Supported Model Formats
110
111
### HDF5 Format Requirements
112
113
The HDF5 file must contain:
114
- Model configuration as JSON string in the `model_config` attribute
115
- Model weights in the `model_weights` group
116
- Optional training configuration in the `training_config` attribute
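Before attempting an import, it can be useful to confirm that a file is an HDF5 container at all. The sketch below checks for the 8-byte HDF5 format signature at the start of the file. It is a coarse pre-check under stated assumptions: it does not verify the `model_config` attribute or the `model_weights` group (that requires an HDF5 reader), and it assumes the signature sits at offset 0, although the HDF5 format also permits it at offsets 512, 1024, and so on when a user block is present. The class and method names are hypothetical.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;

// Hypothetical pre-check (not part of DeepLearning4J).
final class Hdf5SignatureCheck {

    // HDF5 format signature: \211 H D F \r \n \032 \n
    private static final byte[] HDF5_SIGNATURE = {
        (byte) 0x89, 'H', 'D', 'F', '\r', '\n', 0x1a, '\n'
    };

    /** Returns true if the file begins with the HDF5 signature at offset 0. */
    static boolean looksLikeHdf5(Path file) throws IOException {
        byte[] header = new byte[HDF5_SIGNATURE.length];
        try (InputStream in = Files.newInputStream(file)) {
            int read = 0;
            while (read < header.length) {
                int n = in.read(header, read, header.length - read);
                if (n < 0) {
                    return false; // file is shorter than the signature
                }
                read += n;
            }
        }
        return Arrays.equals(header, HDF5_SIGNATURE);
    }

    public static void main(String[] args) throws IOException {
        Path sample = Files.createTempFile("sample-", ".h5");
        sample.toFile().deleteOnExit();
        Files.write(sample, HDF5_SIGNATURE);
        System.out.println("Looks like HDF5: " + looksLikeHdf5(sample));
    }
}
```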

### Keras Model Types

- **Functional API Models**: Complex architectures with multiple inputs/outputs, branching, and merging
- **Sequential Models**: Linear stack of layers

### Supported Keras Versions

This library is designed to work with Keras models from the TensorFlow/Keras ecosystem, particularly the Keras versions that were current around the time of the 0.9.1 release of this library.

## Error Handling

When importing models, several types of errors can occur:

### File Not Found

```java
try {
    ComputationGraph model = KerasModelImport.importKerasModelAndWeights("nonexistent.h5");
} catch (IOException e) {
    System.err.println("Could not read model file: " + e.getMessage());
}
```

### Invalid Configuration

```java
try {
    ComputationGraph model = KerasModelImport.importKerasModelAndWeights("invalid_model.h5");
} catch (InvalidKerasConfigurationException e) {
    System.err.println("Invalid Keras configuration: " + e.getMessage());
}
```

### Unsupported Features

```java
ComputationGraph model;
try {
    model = KerasModelImport.importKerasModelAndWeights("unsupported_model.h5", true);
} catch (UnsupportedKerasConfigurationException e) {
    System.err.println("Unsupported Keras feature: " + e.getMessage());
    // Try with relaxed enforcement; this call can still throw any of the three checked exceptions
    model = KerasModelImport.importKerasModelAndWeights("unsupported_model.h5", false);
}
```
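The strict-then-relaxed retry shown above can be factored into a small reusable helper. The following is a minimal sketch using only the Java standard library. The `Importer` interface and `importWithFallback` method are hypothetical and stand in for any import call that accepts an `enforceTrainingConfig` flag. The `main` method simulates a failing strict import, since DeepLearning4J is not assumed to be on the classpath here.

```java
// Hypothetical helper (not part of DeepLearning4J): retry an import with
// enforceTrainingConfig=false when the strict attempt fails with a given exception type.
final class StrictThenRelaxed {

    /** Stand-in for an import call that takes the enforceTrainingConfig flag. */
    @FunctionalInterface
    interface Importer<T> {
        T load(boolean enforceTrainingConfig) throws Exception;
    }

    static <T> T importWithFallback(Importer<T> importer,
                                    Class<? extends Exception> retryOn) throws Exception {
        try {
            return importer.load(true); // strict attempt first
        } catch (Exception e) {
            if (!retryOn.isInstance(e)) {
                throw e; // unrelated failures (e.g. I/O errors) are not retried
            }
            System.err.println("Strict import failed, retrying relaxed: " + e.getMessage());
            return importer.load(false);
        }
    }

    public static void main(String[] args) throws Exception {
        // Simulated importer: fails when strict, succeeds when relaxed
        String result = importWithFallback(enforce -> {
            if (enforce) {
                throw new IllegalStateException("simulated unsupported training config");
            }
            return "relaxed import";
        }, IllegalStateException.class);
        System.out.println(result);
    }
}
```

With DeepLearning4J available, the call would look something like `importWithFallback(enforce -> KerasModelImport.importKerasModelAndWeights(path, enforce), UnsupportedKerasConfigurationException.class)`, where `path` is the model filename.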