A model zoo for deeplearning4j that provides access to pre-trained neural network models
npx @tessl/cli install tessl/maven-org-deeplearning4j--deeplearning4j-zoo@0.9.00
# DeepLearning4j Zoo

A model zoo for deeplearning4j that provides access to pre-trained neural network models. The library offers popular CNN architectures like AlexNet, VGG, ResNet, and Inception models, as well as RNN models for text generation, with support for downloading and using pre-trained weights from various datasets.

## Package Information

- **Package Name**: deeplearning4j-zoo
- **Package Type**: maven
- **Language**: Java
- **Installation**:
```xml
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-zoo</artifactId>
    <version>0.9.1</version>
</dependency>
```

## Core Imports

```java
import org.deeplearning4j.zoo.*;
import org.deeplearning4j.zoo.model.*;
import org.deeplearning4j.zoo.model.helper.*;
import org.deeplearning4j.zoo.util.imagenet.*;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.activations.Activation;
```

For specific models:

```java
import org.deeplearning4j.zoo.model.AlexNet;
import org.deeplearning4j.zoo.model.VGG16;
import org.deeplearning4j.zoo.model.ResNet50;
import org.deeplearning4j.zoo.model.helper.FaceNetHelper;
import org.deeplearning4j.zoo.model.helper.InceptionResNetHelper;
import org.deeplearning4j.zoo.util.imagenet.ImageNetLabels;
```

## Basic Usage

```java
import java.util.Map;

import org.deeplearning4j.zoo.*;
import org.deeplearning4j.zoo.model.*;
import org.deeplearning4j.nn.api.Model;

// Instantiate a specific model (numLabels, seed, iterations)
AlexNet alexNet = new AlexNet(1000, 42, 1);
Model model = alexNet.init();

// Load pre-trained weights (downloaded and cached on first use)
VGG16 vgg16 = new VGG16(1000, 42, 1);
Model pretrainedModel = vgg16.initPretrained(PretrainedType.IMAGENET);

// Use model selector for multiple models
Map<ZooType, ZooModel> models = ModelSelector.select(ZooType.CNN, 10);
for (Map.Entry<ZooType, ZooModel> entry : models.entrySet()) {
    ZooModel zooModel = entry.getValue();
    Model initializedModel = zooModel.init();
}

// Get model metadata
ModelMetaData metadata = alexNet.metaData();
int[][] inputShape = metadata.getInputShape();
int numOutputs = metadata.getNumOutputs();
```

## Architecture

The deeplearning4j-zoo is built around several key components:

- **Model Interface**: `InstantiableModel` defines the contract for all zoo models, ensuring consistent initialization and metadata access
- **Base Implementation**: `ZooModel` abstract class provides common functionality including pre-trained model downloading, caching, and checksum verification
- **Model Implementations**: Concrete model classes for popular architectures (AlexNet, VGG, ResNet, etc.)
- **Model Selection**: `ModelSelector` utility for programmatically selecting and instantiating multiple models
- **Type System**: Enumerations (`ZooType`, `PretrainedType`) for categorizing models and pre-trained weight types
- **Helper Classes**: Utilities for building complex architectures and decoding predictions

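The checksum verification mentioned for `ZooModel` can be pictured with a small standalone sketch. It assumes an Adler-32 checksum, which fits the `long` value returned by `pretrainedChecksum`, but the algorithm choice is an assumption here rather than something this document states. `ChecksumSketch`, `adler32`, and `isValid` are hypothetical names and are not part of deeplearning4j-zoo.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.zip.Adler32;
import java.util.zip.CheckedInputStream;

// Hypothetical helper for illustration; not part of deeplearning4j-zoo.
final class ChecksumSketch {

    // Streams the file through an Adler-32 checksum (assumed algorithm) and returns its value.
    static long adler32(Path file) throws IOException {
        try (InputStream in = Files.newInputStream(file);
             CheckedInputStream checked = new CheckedInputStream(in, new Adler32())) {
            byte[] buffer = new byte[8192];
            while (checked.read(buffer) != -1) {
                // Reading updates the checksum; the bytes themselves are discarded.
            }
            return checked.getChecksum().getValue();
        }
    }

    // Returns true when the cached file matches the expected checksum.
    static boolean isValid(Path cachedFile, long expectedChecksum) throws IOException {
        return adler32(cachedFile) == expectedChecksum;
    }

    public static void main(String[] args) throws IOException {
        Path tmp = Files.createTempFile("zoo-weights", ".bin");
        try {
            Files.write(tmp, "Wikipedia".getBytes(StandardCharsets.US_ASCII));
            // 0x11E60398 is the well-known Adler-32 value of the ASCII string "Wikipedia".
            System.out.println(isValid(tmp, 0x11E60398L));
        } finally {
            Files.deleteIfExists(tmp);
        }
    }
}
```

A caller would typically compare the computed value with the one reported for the chosen `PretrainedType` and discard the cached file on a mismatch.
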
## Capabilities

### Core Model Interface

Foundation interface and abstract class defining how all zoo models are instantiated, configured, and used. Includes pre-trained model support with automatic downloading and caching.

```java { .api }
interface InstantiableModel {
    void setInputShape(int[][] inputShape);
    Model init();
    ModelMetaData metaData();
    ZooType zooType();
    Class<? extends Model> modelType();
    String pretrainedUrl(PretrainedType pretrainedType);
    long pretrainedChecksum(PretrainedType pretrainedType);
}

abstract class ZooModel<T> implements InstantiableModel {
    static File ROOT_CACHE_DIR;
    boolean pretrainedAvailable(PretrainedType pretrainedType);
    Model initPretrained();
    Model initPretrained(PretrainedType pretrainedType);
}
```

[Core Interface](./core-interface.md)

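The download-and-cache behaviour described above can be sketched as a cache lookup that only fetches a file when it is missing. This is a minimal illustration under stated assumptions, not the library's implementation: `WeightCacheSketch`, `Downloader`, and `resolve` are hypothetical names, the cache directory stands in for `ROOT_CACHE_DIR`, and the URL is a placeholder.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

// Hypothetical sketch of a download cache; not part of deeplearning4j-zoo.
final class WeightCacheSketch {

    // Abstracts the actual transfer so the sketch stays offline and testable.
    interface Downloader {
        void fetch(String url, Path target) throws IOException;
    }

    // Returns the cached file for the URL, downloading it only if it is not cached yet.
    static Path resolve(Path cacheDir, String url, Downloader downloader) throws IOException {
        String fileName = url.substring(url.lastIndexOf('/') + 1);
        Path cached = cacheDir.resolve(fileName);
        if (!Files.exists(cached)) {
            Files.createDirectories(cacheDir);
            downloader.fetch(url, cached);
        }
        return cached;
    }

    public static void main(String[] args) throws IOException {
        Path cacheDir = Files.createTempDirectory("zoo-cache");
        // A fake downloader that writes a few bytes instead of touching the network.
        Downloader fake = (url, target) -> Files.write(target, new byte[] {1, 2, 3});
        Path first = resolve(cacheDir, "https://example.invalid/models/vgg16.zip", fake);
        Path second = resolve(cacheDir, "https://example.invalid/models/vgg16.zip", fake);
        System.out.println(first.equals(second));
    }
}
```

The second call returns the same path without invoking the downloader, which is the property a weight cache needs so that repeated `initPretrained` calls do not re-download large files.
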
### CNN Models

Collection of convolutional neural network implementations including AlexNet, VGG16/19, ResNet50, GoogLeNet, LeNet, and other popular architectures for image classification and computer vision tasks.

```java { .api }
class AlexNet extends ZooModel {
    AlexNet(int numLabels, long seed, int iterations);
    AlexNet(int numLabels, long seed, int iterations, WorkspaceMode workspaceMode);
}

class VGG16 extends ZooModel {
    VGG16(int numLabels, long seed, int iterations);
    VGG16(int numLabels, long seed, int iterations, WorkspaceMode workspaceMode);
}

class ResNet50 extends ZooModel {
    ResNet50(int numLabels, long seed, int iterations);
    ResNet50(int numLabels, long seed, int iterations, WorkspaceMode workspaceMode);
}
```

[CNN Models](./cnn-models.md)

### RNN Models

Recurrent neural network implementations including LSTM models for text generation and sequence modeling tasks.

```java { .api }
class TextGenerationLSTM extends ZooModel {
    TextGenerationLSTM(int numLabels, long seed, int iterations);
    TextGenerationLSTM(int numLabels, long seed, int iterations, WorkspaceMode workspaceMode);
}
```

[RNN Models](./rnn-models.md)

### Model Selection and Utilities

Tools for programmatically selecting, instantiating, and working with multiple zoo models, including helper classes for building custom architectures.

```java { .api }
class ModelSelector {
    static Map<ZooType, ZooModel> select(ZooType zooType);
    static Map<ZooType, ZooModel> select(ZooType zooType, int numLabels);
    static Map<ZooType, ZooModel> select(ZooType zooType, int numLabels, WorkspaceMode workspaceMode);
    static Map<ZooType, ZooModel> select(ZooType zooType, int numLabels, int seed, int iterations, WorkspaceMode workspaceMode);
    static Map<ZooType, ZooModel> select(WorkspaceMode workspaceMode, ZooType... zooTypes);
    static Map<ZooType, ZooModel> select(ZooType... zooTypes);
    static Map<ZooType, ZooModel> select(int numLabels, int seed, int iterations, WorkspaceMode workspaceMode, ZooType... zooTypes);
}
```

[Model Selection](./model-selection.md)

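Because `select` returns a map keyed by `ZooType`, a group value such as `ZooType.CNN` presumably expands to several concrete model types. The sketch below illustrates that idea with a local copy of the enum. The grouping (everything from `SIMPLECNN` through `FACENETNN4SMALL2` counted as CNN, `TEXTGENLSTM` as RNN) is an assumption inferred from the enum order and the model sections of this document, and `ZooTypeGroupsSketch` and `expand` are hypothetical names.

```java
import java.util.EnumSet;
import java.util.Set;

// Hypothetical illustration of group-to-concrete-type expansion; not the ModelSelector source.
final class ZooTypeGroupsSketch {

    // Local stand-in that mirrors the constants of the library's ZooType enum.
    enum ZooType {
        ALL, CNN, SIMPLECNN, ALEXNET, LENET, GOOGLENET, VGG16, VGG19,
        RESNET50, INCEPTIONRESNETV1, FACENETNN4SMALL2, RNN, TEXTGENLSTM
    }

    // Expands a group type into concrete model types; a concrete type maps to itself.
    static Set<ZooType> expand(ZooType type) {
        switch (type) {
            case CNN:
                return EnumSet.range(ZooType.SIMPLECNN, ZooType.FACENETNN4SMALL2);
            case RNN:
                return EnumSet.of(ZooType.TEXTGENLSTM);
            case ALL:
                Set<ZooType> all = EnumSet.range(ZooType.SIMPLECNN, ZooType.FACENETNN4SMALL2);
                all.add(ZooType.TEXTGENLSTM);
                return all;
            default:
                return EnumSet.of(type);
        }
    }

    public static void main(String[] args) {
        System.out.println(expand(ZooType.CNN));
        System.out.println(expand(ZooType.ALEXNET));
    }
}
```

With the real API, the keys of the map returned by `ModelSelector.select(ZooType.CNN)` would play the role of the expanded set, each paired with an instantiable `ZooModel`.
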
### ImageNet Integration

Utilities for working with ImageNet classifications, including label decoding and prediction interpretation for models trained on ImageNet dataset.

```java { .api }
class ImageNetLabels {
    ImageNetLabels();
    String getLabel(int n);
    String decodePredictions(INDArray predictions);
}
```

[ImageNet Integration](./imagenet-integration.md)

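Decoding predictions amounts to ranking class probabilities and pairing the best indices with their labels. The following standalone sketch shows that step on plain arrays. It does not reproduce the output format of `decodePredictions`; `TopKSketch`, `topK`, and the three toy labels are hypothetical and only stand in for an `INDArray` of scores and the ImageNet label list.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;

// Hypothetical top-k decoder for illustration; not the ImageNetLabels implementation.
final class TopKSketch {

    // Returns the k most probable labels, highest first, formatted as "label: percent".
    static List<String> topK(float[] probabilities, String[] labels, int k) {
        if (probabilities.length != labels.length) {
            throw new IllegalArgumentException("probabilities and labels must have the same length");
        }
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < probabilities.length; i++) {
            indices.add(i);
        }
        // Sort class indices by descending probability.
        indices.sort(Comparator.comparingDouble((Integer i) -> probabilities[i]).reversed());

        List<String> result = new ArrayList<>();
        for (int rank = 0; rank < Math.min(k, indices.size()); rank++) {
            int idx = indices.get(rank);
            result.add(String.format(Locale.ROOT, "%s: %.1f%%", labels[idx], probabilities[idx] * 100));
        }
        return result;
    }

    public static void main(String[] args) {
        float[] probabilities = {0.1f, 0.7f, 0.2f};
        String[] labels = {"cat", "dog", "fox"};
        for (String line : topK(probabilities, labels, 2)) {
            System.out.println(line);
        }
    }
}
```

With the real API, `getLabel(int n)` would supply the label for each ranked index of the network's output row.
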
## Types

```java { .api }
class ModelMetaData {
    ModelMetaData(int[][] inputShape, int numOutputs, ZooType zooType);
    int[][] getInputShape();
    int getNumOutputs();
    ZooType getZooType();
    boolean useMDS();
}

enum ZooType {
    ALL, CNN, SIMPLECNN, ALEXNET, LENET, GOOGLENET, VGG16, VGG19,
    RESNET50, INCEPTIONRESNETV1, FACENETNN4SMALL2, RNN, TEXTGENLSTM
}

enum PretrainedType {
    IMAGENET, MNIST, CIFAR10, VGGFACE
}

enum WorkspaceMode {
    NONE, SINGLE, SEPARATE
}

class MultiLayerConfiguration {
    // Network configuration for MultiLayerNetwork models
}

class ComputationGraphConfiguration {
    // Network configuration for ComputationGraph models

    class GraphBuilder {
        // Builder for constructing computation graphs
        GraphBuilder addInputs(String... inputs);
        GraphBuilder addLayer(String layerName, Layer layer, String... inputs);
        GraphBuilder addVertex(String vertexName, GraphVertex vertex, String... inputs);
        ComputationGraphConfiguration build();
    }
}

abstract class Layer {
    // Base class for all neural network layers
}

class ConvolutionLayer extends Layer {
    enum AlgoMode {
        NO_WORKSPACE, PREFER_FASTEST, USER_SPECIFIED
    }
}

class SubsamplingLayer extends Layer {
    enum PoolingType {
        MAX, AVG, SUM, PNORM
    }
}

class DenseLayer extends Layer {
    // Fully connected layer
}

class BatchNormalization extends Layer {
    // Batch normalization layer
}

class ActivationLayer extends Layer {
    // Activation function layer
}

abstract class GraphVertex {
    // Base class for graph vertices
}

class MergeVertex extends GraphVertex {
    // Vertex that merges multiple inputs
}

class ElementWiseVertex extends GraphVertex {
    // Vertex for element-wise operations
}

interface Model {
    // Common interface implemented by DeepLearning4j models.
    // Inference and iterator-based training are declared on the concrete network classes below.
}

class MultiLayerNetwork implements Model {
    // Feed-forward neural network implementation
    INDArray output(INDArray input);
    void fit(DataSetIterator iterator);
}

class ComputationGraph implements Model {
    // Computation graph implementation for complex architectures
    INDArray[] output(INDArray... input);
    void fit(DataSetIterator iterator);
}

enum Activation {
    RELU, TANH, SIGMOID, SOFTMAX, IDENTITY, LEAKYRELU
}
```