# DeepLearning4J Model Import

DeepLearning4J Model Import is a Java library for importing pre-trained neural network models from Keras into DeepLearning4J. It handles both Sequential and Functional API models, including their configurations (JSON, YAML, or embedded in an HDF5 file) and their trained weights (HDF5).

## Package Information

- **Package Name**: deeplearning4j-modelimport
- **Package Type**: Maven
- **Group ID**: org.deeplearning4j
- **Artifact ID**: deeplearning4j-modelimport
- **Language**: Java
- **Installation**: `<dependency><groupId>org.deeplearning4j</groupId><artifactId>deeplearning4j-modelimport</artifactId><version>0.9.1</version></dependency>`

## Core Imports

```java
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
```

For exceptions:

```java
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
```

## Basic Usage

```java
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

public class KerasImportExample {
    // The import methods throw IOException plus the two Keras configuration exceptions
    public static void main(String[] args) throws Exception {
        // Import a complete Keras Functional API model (configuration and weights in one HDF5 file)
        ComputationGraph functionalModel =
                KerasModelImport.importKerasModelAndWeights("path/to/model.h5");

        // Import a Keras Sequential model (HDF5 format)
        MultiLayerNetwork sequentialModel =
                KerasModelImport.importKerasSequentialModelAndWeights("path/to/sequential.h5");

        // Import from a separate JSON configuration file and HDF5 weights file
        ComputationGraph separateFiles = KerasModelImport.importKerasModelAndWeights(
                "path/to/model.json",
                "path/to/weights.h5");

        // Import the configuration only (no weights)
        ComputationGraphConfiguration config =
                KerasModelImport.importKerasModelConfiguration("path/to/model.json");
    }
}
```

## Architecture

DeepLearning4J Model Import is built around several key components:

- **Static Import API**: the `KerasModelImport` class provides static methods for common import operations
- **Model Classes**: `KerasModel` and `KerasSequentialModel` handle Functional API and Sequential Keras models respectively, with builder-pattern support
- **Layer Mapping**: a set of layer classes that map Keras layers to their DeepLearning4J equivalents
- **HDF5 Support**: the `Hdf5Archive` class reads model configurations and weights from HDF5 files
- **Exception Hierarchy**: separate exceptions for invalid configurations and for unsupported features
- **Pre-trained Models**: utilities for popular pre-trained models such as VGG16 (deprecated; use the DL4J Zoo instead)

## Capabilities

### Model Import

Core functionality for importing complete Keras models, including their trained weights, into DeepLearning4J. Functional API models are imported as a `ComputationGraph`, and Sequential models as a `MultiLayerNetwork`.

```java { .api }
// Import complete models
public static ComputationGraph importKerasModelAndWeights(String modelHdf5Filename)
        throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

public static ComputationGraph importKerasModelAndWeights(String modelHdf5Filename, boolean enforceTrainingConfig)
        throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelHdf5Filename)
        throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelHdf5Filename, boolean enforceTrainingConfig)
        throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
```

[Model Import](./model-import.md)
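To call the right single-file method, you first need to know whether a file holds a Sequential or a Functional model. Keras records this in the top-level `class_name` of the model configuration: `"Sequential"` for Sequential models and `"Model"` for Functional API models (Keras 1.x and 2.x). Inside a single HDF5 file, the same JSON is stored in the `model_config` attribute. The sketch below uses only the JDK. It reads that field and returns the name of the matching `KerasModelImport` method. The `KerasImportRouter` class is hypothetical. The regex approach assumes the top-level `class_name` is the first one in the text, which holds when the model's own `class_name` key comes first. A real application would use a JSON parser instead.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Hypothetical helper (not part of DeepLearning4J): picks the KerasModelImport
 * entry point from the top-level "class_name" of a Keras model configuration.
 */
final class KerasImportRouter {

    // Matches "class_name": "<Name>" and captures <Name>
    private static final Pattern CLASS_NAME =
            Pattern.compile("\"class_name\"\\s*:\\s*\"([^\"]+)\"");

    private KerasImportRouter() {}

    /** Returns the top-level class_name, assuming it is the first one in the JSON text. */
    static String modelClassName(String modelJson) {
        Matcher m = CLASS_NAME.matcher(modelJson);
        if (!m.find()) {
            throw new IllegalArgumentException("No class_name found in model configuration");
        }
        return m.group(1);
    }

    /** Names the KerasModelImport method that matches the model type. */
    static String importMethodFor(String modelJson) {
        return "Sequential".equals(modelClassName(modelJson))
                ? "importKerasSequentialModelAndWeights" // returns MultiLayerNetwork
                : "importKerasModelAndWeights";          // returns ComputationGraph
    }

    public static void main(String[] args) {
        String sequential = "{\"class_name\": \"Sequential\", \"config\": []}";
        String functional = "{\"class_name\": \"Model\", \"config\": {\"layers\": []}}";
        System.out.println(importMethodFor(sequential));
        System.out.println(importMethodFor(functional));
    }
}
```

Any `class_name` other than `"Sequential"` falls through to the `ComputationGraph` path, because that path is the general graph representation.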

### Configuration Import

Imports model configurations without weights. This is useful for recreating a Keras architecture in DeepLearning4J and training it there.

```java { .api }
public static ComputationGraphConfiguration importKerasModelConfiguration(String modelJsonFilename)
        throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

public static ComputationGraphConfiguration importKerasModelConfiguration(String modelJsonFilename, boolean enforceTrainingConfig)
        throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

public static MultiLayerConfiguration importKerasSequentialConfiguration(String modelJsonFilename)
        throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

public static MultiLayerConfiguration importKerasSequentialConfiguration(String modelJsonFilename, boolean enforceTrainingConfig)
        throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
```

[Configuration Import](./configuration-import.md)
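Most import methods have an overload that takes `enforceTrainingConfig`. As we understand the flag, it works like this:

- When `true`, training-related settings that DeepLearning4J cannot map (for example an unsupported loss function or regularizer) make the import fail with `UnsupportedKerasConfigurationException`.
- When `false`, those settings are logged as warnings and the import continues.

The JDK-only sketch below models that strict-versus-lenient policy. `TrainingConfigPolicy`, its stand-in exception, and the example setting names are hypothetical and do not reflect DL4J's actual list of supported settings.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

/**
 * Hypothetical model of the enforceTrainingConfig switch: strict mode rejects
 * unsupported training settings, lenient mode collects warnings and continues.
 */
final class TrainingConfigPolicy {

    /** Stand-in for UnsupportedKerasConfigurationException. */
    static final class UnsupportedSettingException extends Exception {
        UnsupportedSettingException(String message) {
            super(message);
        }
    }

    private TrainingConfigPolicy() {}

    /**
     * Checks each requested training setting against the supported set.
     * Throws on the first unsupported setting when enforce is true;
     * otherwise returns one warning message per unsupported setting.
     */
    static List<String> check(List<String> requested, Set<String> supported, boolean enforce)
            throws UnsupportedSettingException {
        List<String> warnings = new ArrayList<>();
        for (String setting : requested) {
            if (supported.contains(setting)) {
                continue;
            }
            String message = "Unsupported training setting: " + setting;
            if (enforce) {
                throw new UnsupportedSettingException(message);
            }
            warnings.add(message);
        }
        return warnings;
    }

    public static void main(String[] args) throws UnsupportedSettingException {
        Set<String> supported = Set.of("categorical_crossentropy", "mse"); // illustrative only
        List<String> requested = List.of("mse", "my_custom_loss");
        // Lenient: the import would proceed, logging these warnings
        System.out.println(check(requested, supported, false));
        // Strict: the same input fails fast
        try {
            check(requested, supported, true);
        } catch (UnsupportedSettingException e) {
            System.out.println("Rejected: " + e.getMessage());
        }
    }
}
```

Strict mode suits pipelines that must reproduce Keras training behaviour exactly. Lenient mode suits inference-only use, where the training settings are never used.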
106
107
### Separate Files Import
108
109
Import models where configuration (JSON) and weights (HDF5) are stored in separate files, common in Keras workflows.
110
111
```java { .api }
112
public static ComputationGraph importKerasModelAndWeights(String modelJsonFilename, String weightsHdf5Filename)
113
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
114
115
public static ComputationGraph importKerasModelAndWeights(String modelJsonFilename, String weightsHdf5Filename, boolean enforceTrainingConfig)
116
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
117
118
public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelJsonFilename, String weightsHdf5Filename)
119
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
120
121
public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelJsonFilename, String weightsHdf5Filename, boolean enforceTrainingConfig)
122
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
123
```
124
125
[Separate Files Import](./separate-files-import.md)
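Passing a JSON file where an HDF5 file is expected (or the reverse) typically fails deep inside the import. A cheap pre-flight check is to look for the 8-byte HDF5 format signature (`\211HDF\r\n\032\n`) at the start of each file. The JDK-only sketch below uses it to decide between the single-file and separate-files overloads. `ImportInputCheck` and its rules are hypothetical. The check looks at offset 0 only. HDF5 also permits a user block that shifts the signature to offset 512, 1024, and so on, but files written by Keras through h5py do not use one by default.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;

/**
 * Hypothetical pre-flight check (not part of DeepLearning4J): decides whether a
 * path pair fits the single-file or the separate-files import overload.
 */
final class ImportInputCheck {

    // The 8-byte HDF5 format signature: \211 H D F \r \n \032 \n
    private static final byte[] HDF5_SIGNATURE =
            {(byte) 0x89, 'H', 'D', 'F', '\r', '\n', 0x1A, '\n'};

    private ImportInputCheck() {}

    /** True if the file starts with the HDF5 signature (offset 0 only; user blocks are ignored). */
    static boolean looksLikeHdf5(Path file) throws IOException {
        try (InputStream in = Files.newInputStream(file)) {
            byte[] header = in.readNBytes(HDF5_SIGNATURE.length);
            return Arrays.equals(header, HDF5_SIGNATURE);
        }
    }

    /**
     * Returns "single-file" for one HDF5 model file (weights == null), or
     * "separate-files" for a non-HDF5 (e.g. JSON) config plus an HDF5 weights file.
     */
    static String chooseOverload(Path model, Path weights) throws IOException {
        boolean modelIsHdf5 = looksLikeHdf5(model);
        if (weights == null) {
            if (!modelIsHdf5) {
                throw new IllegalArgumentException(model + " is not HDF5; supply a weights file too");
            }
            return "single-file";
        }
        if (modelIsHdf5 || !looksLikeHdf5(weights)) {
            throw new IllegalArgumentException("Expected a JSON config and an HDF5 weights file");
        }
        return "separate-files";
    }

    public static void main(String[] args) throws IOException {
        if (args.length == 0) {
            System.err.println("usage: ImportInputCheck <model file> [<weights file>]");
            return;
        }
        Path weights = args.length > 1 ? Path.of(args[1]) : null;
        System.out.println(chooseOverload(Path.of(args[0]), weights));
    }
}
```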
126
127
### Builder Pattern
128
129
Advanced model construction using builder pattern for fine-grained control over the import process.
130
131
```java { .api }
132
public static class ModelBuilder {
133
public ModelBuilder modelJson(String modelJson);
134
public ModelBuilder modelJsonFilename(String modelJsonFilename) throws IOException;
135
public ModelBuilder modelJsonInputStream(InputStream modelJsonInputStream) throws IOException;
136
public ModelBuilder modelYaml(String modelYaml);
137
public ModelBuilder modelYamlFilename(String modelYamlFilename) throws IOException;
138
public ModelBuilder modelYamlInputStream(InputStream modelYamlInputStream) throws IOException;
139
public ModelBuilder modelHdf5Filename(String modelHdf5Filename)
140
throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException;
141
public ModelBuilder weightsHdf5Filename(String weightsHdf5Filename);
142
public ModelBuilder trainingJson(String trainingJson);
143
public ModelBuilder trainingJsonInputStream(InputStream trainingJsonInputStream) throws IOException;
144
public ModelBuilder enforceTrainingConfig(boolean enforceTrainingConfig);
145
public KerasModel buildModel()
146
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
147
public KerasSequentialModel buildSequential()
148
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
149
}
150
```
151
152
[Builder Pattern](./builder-pattern.md)
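The value of the builder over the static methods is that inputs are collected one setter at a time, and the combination is validated only when the model is built. The JDK-only sketch below shows that fluent style in isolation. `ImportSpec` and its single validation rule ("exactly one model source") are hypothetical and are not DeepLearning4J's `ModelBuilder`. The `false` default for `enforceTrainingConfig` is just Java's default for `boolean`, not necessarily the library's.

```java
/**
 * Hypothetical, JDK-only illustration of the fluent builder style used by
 * ModelBuilder. It is not DeepLearning4J code; it only records inputs and
 * validates them when build() is called.
 */
final class ImportSpec {
    final String modelJsonFilename;
    final String modelHdf5Filename;
    final String weightsHdf5Filename;
    final boolean enforceTrainingConfig;

    private ImportSpec(Builder b) {
        this.modelJsonFilename = b.modelJsonFilename;
        this.modelHdf5Filename = b.modelHdf5Filename;
        this.weightsHdf5Filename = b.weightsHdf5Filename;
        this.enforceTrainingConfig = b.enforceTrainingConfig;
    }

    /** In this sketch, weights are loaded if a full HDF5 model or a weights file was given. */
    boolean importsWeights() {
        return modelHdf5Filename != null || weightsHdf5Filename != null;
    }

    static final class Builder {
        private String modelJsonFilename;
        private String modelHdf5Filename;
        private String weightsHdf5Filename;
        private boolean enforceTrainingConfig;

        // Each setter records one input and returns the builder for chaining
        Builder modelJsonFilename(String f) { this.modelJsonFilename = f; return this; }
        Builder modelHdf5Filename(String f) { this.modelHdf5Filename = f; return this; }
        Builder weightsHdf5Filename(String f) { this.weightsHdf5Filename = f; return this; }
        Builder enforceTrainingConfig(boolean e) { this.enforceTrainingConfig = e; return this; }

        /** Validates the combination of inputs, then freezes them into an ImportSpec. */
        ImportSpec build() {
            if ((modelJsonFilename == null) == (modelHdf5Filename == null)) {
                throw new IllegalStateException("Set exactly one of modelJsonFilename or modelHdf5Filename");
            }
            return new ImportSpec(this);
        }
    }

    public static void main(String[] args) {
        ImportSpec spec = new ImportSpec.Builder()
                .modelJsonFilename("path/to/model.json")
                .weightsHdf5Filename("path/to/weights.h5")
                .enforceTrainingConfig(true)
                .build();
        System.out.println(spec.importsWeights()); // prints true
    }
}
```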
153
154
### Layer Support
155
156
Comprehensive mapping of Keras layer types to DeepLearning4J layers, including all common neural network components.
157
158
```java { .api }
159
// Core layer types supported:
160
// - Dense (fully connected)
161
// - Convolution1D/2D
162
// - LSTM
163
// - Dropout
164
// - Activation
165
// - BatchNormalization
166
// - Pooling (Max/Average)
167
// - Flatten
168
// - Embedding
169
// - Merge
170
// - Input
171
// And more...
172
```
173
174
[Layer Support](./layer-support.md)
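The list above is partial, so before importing a large model it can help to check its layer types and fail early with a readable report. The JDK-only sketch below does that for a list of Keras layer class names. `LayerSupportScan` is hypothetical. `KNOWN_SUPPORTED` is an illustrative subset transcribed from the list above (using Keras class names such as `InputLayer` and `MaxPooling2D`), not DeepLearning4J's actual layer registry. Extracting the class names from the model JSON is left out.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

/**
 * Hypothetical pre-import scan (not part of DeepLearning4J): lists the Keras
 * layer class names that are absent from a known-supported set.
 */
final class LayerSupportScan {

    // Illustrative subset based on the layer list above; NOT DL4J's full registry
    static final Set<String> KNOWN_SUPPORTED = Set.of(
            "Dense", "Convolution1D", "Convolution2D", "LSTM", "Dropout",
            "Activation", "BatchNormalization", "MaxPooling1D", "MaxPooling2D",
            "AveragePooling1D", "AveragePooling2D", "Flatten", "Embedding",
            "Merge", "InputLayer");

    private LayerSupportScan() {}

    /** Returns layer class names not in the supported set, in input order, without duplicates. */
    static List<String> unsupported(List<String> layerClassNames, Set<String> supported) {
        List<String> missing = new ArrayList<>();
        for (String name : layerClassNames) {
            if (!supported.contains(name) && !missing.contains(name)) {
                missing.add(name);
            }
        }
        return missing;
    }

    public static void main(String[] args) {
        // "MyCustomLayer" stands in for a user-defined Keras layer
        List<String> layers = List.of("InputLayer", "Dense", "MyCustomLayer", "Dense", "MyCustomLayer");
        System.out.println(unsupported(layers, KNOWN_SUPPORTED)); // prints [MyCustomLayer]
    }
}
```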
175
176
### Pre-trained Models (Deprecated)
177
178
Legacy support for popular pre-trained models. Use deeplearning4j-zoo module for new projects.
179
180
```java { .api }
181
public enum TrainedModels {
182
VGG16, VGG16NOTOP;
183
184
public ComputationGraph getComputationGraph() throws IOException;
185
public DataSetPreProcessor getDataSetPreProcessor();
186
public ArrayList<String> getLabels();
187
}
188
```
189
190
[Pre-trained Models](./pretrained-models.md)
191
192
## Types
193
194
### Core Model Types
195
196
```java { .api }
197
// Main model types from DeepLearning4J
198
public class ComputationGraph {
199
// Functional API models
200
}
201
202
public class MultiLayerNetwork {
203
// Sequential models
204
}
205
206
public class ComputationGraphConfiguration {
207
// Configuration for Functional API models
208
}
209
210
public class MultiLayerConfiguration {
211
// Configuration for Sequential models
212
}
213
```
214
215
### Configuration Types
216
217
```java { .api }
218
public class KerasModel {
219
public ComputationGraph getComputationGraph()
220
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
221
public ComputationGraph getComputationGraph(boolean importWeights)
222
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
223
public ComputationGraphConfiguration getComputationGraphConfiguration()
224
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
225
}
226
227
public class KerasSequentialModel extends KerasModel {
228
public MultiLayerNetwork getMultiLayerNetwork()
229
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
230
public MultiLayerNetwork getMultiLayerNetwork(boolean importWeights)
231
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
232
public MultiLayerConfiguration getMultiLayerConfiguration()
233
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;
234
}
235
```
236
237
### Exception Types
238
239
```java { .api }
240
public class InvalidKerasConfigurationException extends Exception {
241
public InvalidKerasConfigurationException(String message);
242
public InvalidKerasConfigurationException(String message, Throwable cause);
243
public InvalidKerasConfigurationException(Throwable cause);
244
}
245
246
public class UnsupportedKerasConfigurationException extends Exception {
247
public UnsupportedKerasConfigurationException(String message);
248
public UnsupportedKerasConfigurationException(String message, Throwable cause);
249
public UnsupportedKerasConfigurationException(Throwable cause);
250
}
251
```
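Both exceptions extend `Exception` directly, so a caller can catch them separately and respond differently:

- `InvalidKerasConfigurationException` signals a configuration that is malformed or internally inconsistent.
- `UnsupportedKerasConfigurationException` signals a valid Keras configuration that uses a feature DeepLearning4J cannot import.

The JDK-only sketch below shows that three-way `catch` around an import call, together with the `IOException` that every import method also declares. The two exception classes here are local stand-ins mirroring the hierarchy above. `fakeImport` is a hypothetical placeholder for a real `KerasModelImport` call.

```java
import java.io.IOException;

/**
 * Hypothetical sketch of handling the three checked exceptions thrown by the
 * import methods. The two Keras exception classes are local stand-ins that
 * mirror the documented hierarchy (both extend Exception directly).
 */
final class ImportErrorHandling {

    static final class InvalidKerasConfigurationException extends Exception {
        InvalidKerasConfigurationException(String message) { super(message); }
    }

    static final class UnsupportedKerasConfigurationException extends Exception {
        UnsupportedKerasConfigurationException(String message) { super(message); }
    }

    private ImportErrorHandling() {}

    /** Placeholder for an import call; fails in the way selected by mode. */
    static String fakeImport(String mode)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        switch (mode) {
            case "io": throw new IOException("cannot read file");
            case "invalid": throw new InvalidKerasConfigurationException("missing class_name");
            case "unsupported": throw new UnsupportedKerasConfigurationException("unsupported layer");
            default: return "imported";
        }
    }

    /** Runs the placeholder import and turns each failure type into distinct guidance. */
    static String importOrExplain(String mode) {
        try {
            return fakeImport(mode);
        } catch (IOException e) {
            return "I/O problem: check the path and permissions (" + e.getMessage() + ")";
        } catch (InvalidKerasConfigurationException e) {
            return "Malformed Keras configuration: re-export the model (" + e.getMessage() + ")";
        } catch (UnsupportedKerasConfigurationException e) {
            return "Valid Keras feature not supported by the importer (" + e.getMessage() + ")";
        }
    }

    public static void main(String[] args) {
        for (String mode : new String[] {"ok", "io", "invalid", "unsupported"}) {
            System.out.println(mode + " -> " + importOrExplain(mode));
        }
    }
}
```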
252
253
### Utility Types
254
255
```java { .api }
256
public class Hdf5Archive {
257
public Hdf5Archive(String archiveFilename);
258
public INDArray readDataSet(String dataSetName, String groupName);
259
public String readAttributeAsString(String attributeName, String objectName);
260
public String readAttributeAsJson(String attributeName);
261
public List<String> getGroups();
262
public List<String> getGroups(String groupName);
263
public List<String> getDataSets(String groupName);
264
public boolean hasAttribute(String attributeName);
265
}
266
```
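A full Keras model file written by `model.save()` has this layout:

- The architecture JSON is stored in a root attribute named `model_config`.
- The compile settings are stored in `training_config`, which is present only when the model was compiled and saved with its optimizer.
- The weights are stored under a `model_weights` group.

The JDK-only sketch below shows the attribute reads an importer would make against that layout. `AttributeSource` is a hypothetical stand-in exposing just `hasAttribute` and `readAttributeAsJson` from the `Hdf5Archive` listing above, so the logic runs without an HDF5 library. `KerasHdf5Layout` is not DeepLearning4J code.

```java
import java.util.Map;
import java.util.Optional;

/**
 * Hypothetical sketch of the reads an importer makes against a full Keras HDF5
 * model file. AttributeSource stands in for two Hdf5Archive methods.
 */
final class KerasHdf5Layout {

    /** Stand-in for the attribute-reading part of Hdf5Archive. */
    interface AttributeSource {
        boolean hasAttribute(String attributeName);
        String readAttributeAsJson(String attributeName);
    }

    // Root attribute names written by Keras model.save()
    static final String MODEL_CONFIG = "model_config";
    static final String TRAINING_CONFIG = "training_config";

    private KerasHdf5Layout() {}

    /** The architecture JSON is required; without it the file is not a full model file. */
    static String modelConfig(AttributeSource archive) {
        if (!archive.hasAttribute(MODEL_CONFIG)) {
            throw new IllegalArgumentException("No " + MODEL_CONFIG + " attribute: weights-only file?");
        }
        return archive.readAttributeAsJson(MODEL_CONFIG);
    }

    /** The training JSON is optional: present only for compiled models saved with an optimizer. */
    static Optional<String> trainingConfig(AttributeSource archive) {
        return archive.hasAttribute(TRAINING_CONFIG)
                ? Optional.of(archive.readAttributeAsJson(TRAINING_CONFIG))
                : Optional.empty();
    }

    /** In-memory AttributeSource backed by a map, for demonstration and tests. */
    static AttributeSource fromMap(Map<String, String> attributes) {
        return new AttributeSource() {
            @Override public boolean hasAttribute(String name) { return attributes.containsKey(name); }
            @Override public String readAttributeAsJson(String name) { return attributes.get(name); }
        };
    }

    public static void main(String[] args) {
        AttributeSource archive = fromMap(Map.of(MODEL_CONFIG, "{\"class_name\": \"Sequential\"}"));
        System.out.println(modelConfig(archive));
        System.out.println(trainingConfig(archive).isPresent()); // prints false
    }
}
```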